理论上,“Attention is All You Need”。但在实践中,我们还需要像 FlashAttention 这样经过优化的注意力实现。
尽管这些融合的注意力实现显著提升了性能并支持长上下文,但这种效率是以牺牲灵活性为代价的。你无法再通过编写几个 PyTorch 算子来尝试新的注意力变体——通常你需要编写一个新的自定义核函数!这对于机器学习研究人员来说就像是一种“软件彩票”——如果你的注意力变体不适合现有的优化核函数,你就会面临缓慢的运行时和 CUDA 内存不足(OOM)的问题。
例如注意力变体有:因果注意力(Causal)、相对位置编码(Relative Positional Embeddings)、ALiBi、滑动窗口注意力(Sliding Window Attention)、PrefixLM、文档掩码/样本打包/不规则张量(Document Masking/Sample Packing/Jagged Tensors)、Tanh 软限幅(Tanh Soft-Capping)、分页注意力(PagedAttention)等。更糟糕的是,人们通常想要这些变体的组合!例如,滑动窗口注意力 + 文档掩码 + 因果注意力 + 上下文并行?或者分页注意力 + 滑动窗口 + Tanh 软限幅怎么样?
下图左侧的图片(原文中可能包含)展示了目前的现状——一些掩码 + 偏置 + 设置的组合已有现成的核函数实现。但各种选项导致设置数量呈指数级增长,因此总体支持非常零星。更糟糕的是,研究人员提出的新注意力变体将完全没有支持。
为了一劳永逸地解决这个超立方体问题,我们引入了 FlexAttention,这是一个新的 PyTorch API。
- 我们提供了一个灵活的 API,只需几行地道的 PyTorch 代码即可实现许多注意力变体(包括目前博文中提到的所有变体)。
- 我们通过
torch.compile
将其编译为一个融合的 FlashAttention 核函数,生成一个不占用额外内存且性能可与手写核函数媲美的 FlashAttention 核函数。 - 我们还利用 PyTorch 的 autograd 机制自动生成反向传播。
- 最后,我们还可以利用注意力掩码中的稀疏性,从而比标准注意力实现获得显著提升。
有了 FlexAttention,我们希望尝试新的注意力变体只受限于你的想象力。
你可以在 Attention Gym 找到许多 FlexAttention 示例:https://github.com/pytorch-labs/attention-gym。如果你有任何很酷的应用,欢迎提交示例!
PS:我们还发现这个 API 非常令人兴奋,因为它以一种有趣的方式利用了许多现有的 PyTorch 基础设施——更多内容将在最后介绍。
FlexAttention
以下是经典的注意力公式
代码形式如下
Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
probabilities = softmax(score, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V
FlexAttention 允许使用用户自定义函数 score_mod:
代码形式如下
Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
modified_scores: Tensor[batch_size, num_heads, sequence_length, sequence_length] = score_mod(score)
probabilities = softmax(modified_scores, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V
此函数允许你在 softmax 计算之前 *修改* 注意力分数。令人惊讶的是,这足以实现绝大多数注意力变体(示例如下)!
具体来说,score_mod
的预期函数签名有些特别。
def score_mod(score: f32[], b: i32[], h: i32[], q_idx: i32[], kv_idx: i32[])
return score # noop - standard attention
换句话说,score
是一个标量 PyTorch 张量,表示查询 token 和键 token 的点积。其余参数告诉你当前正在计算 *哪个* 点积——b
(当前批次元素)、h
(当前注意力头)、q_idx
(查询中的位置)、kv_idx
(键/值张量中的位置)。
要应用此函数,可以将其实现为
for b in range(batch_size):
for h in range(num_heads):
for q_idx in range(sequence_length):
for kv_idx in range(sequence_length):
modified_scores[b, h, q_idx, kv_idx] = score_mod(scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)
当然,这并非 FlexAttention 的底层实现方式。通过利用 torch.compile
,我们自动将你的函数编译成单个*融合*的 FlexAttention 核函数——保证有效,否则退款!
这个 API 最终出人意料地具有表现力。让我们看一些示例。
score_mod 示例
完全注意力
首先来看“完全注意力”,即标准的双向注意力。在这种情况下,score_mod
是一个无操作函数——它接收分数作为输入,然后原样返回。
def noop(score, b, h, q_idx, kv_idx):
return score
及其端到端(包括前向 *和* 反向)使用方法
from torch.nn.attention.flex_attention import flex_attention
flex_attention(query, key, value, score_mod=noop).sum().backward()
相对位置编码
一种常见的注意力变体是“相对位置编码”。相对位置编码不编码绝对距离,而是根据查询和键之间的“距离”调整分数。
def relative_positional(score, b, h, q_idx, kv_idx):
return score + (q_idx - kv_idx)
注意,与典型实现不同,这*不需要*实例化一个 SxS 张量。相反,FlexAttention 会在核函数内部“即时”计算偏置值,从而显著改善内存和性能。
ALiBi 偏置
来源:Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
ALiBi 在《Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation》中被提出,并声称在推理时具有有利于长度外推的特性。值得注意的是,MosaicML 曾指出“缺乏核函数支持”是他们最终从 ALiBi 转向旋转位置编码(rotary embeddings)的主要原因。
Alibi 与相对位置编码类似,但有一个例外——它有一个通常预计算的每个头部的因子。
alibi_bias = generate_alibi_bias() # [num_heads]
def alibi(score, b, h, q_idx, kv_idx):
bias = alibi_bias[h] * (kv_idx - q_idx)
return score + bias
这展示了 torch.compile
提供的一个有趣的灵活性——即使 alibi_bias
*未明确作为输入传入*,我们也可以从中加载!生成的 Triton 核函数将计算从 alibi_bias
张量中的正确加载并将其融合。请注意,即使你重新生成 alibi_bias
,我们仍然不需要重新编译。
软限幅
软限幅是一种在 Gemma2 和 Grok-1 中使用的技术,它防止 logits 增长过大。在 FlexAttention 中,它看起来像
softcap = 20
def soft_cap(score, b, h, q_idx, kv_idx):
score = score / softcap
score = torch.tanh(score)
score = score * softcap
return score
注意,这里我们也自动从前向传播生成了反向传播。此外,尽管这个实现是语义正确的,但考虑到性能原因,我们可能希望在这种情况下使用 tanh 近似。更多详情请参见 attention-gym。
因果掩码
尽管双向注意力是最简单的,但最初的《Attention is All You Need》论文和绝大多数大型语言模型(LLMs)在解码器模式下使用注意力,其中每个 token 只能关注其之前的 token。人们通常将其视为一个下三角掩码,但使用 score_mod
API,它可以表达为
def causal_mask(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, -float("inf"))
基本上,如果查询 token “在”键 token “之后”,我们保留分数。否则,我们将其掩码掉,方法是将其设置为 -inf,从而确保它不参与 softmax 计算。
然而,与其它修改相比,掩码是特殊的——如果某些内容被掩码掉,我们可以完全跳过其计算!在这种情况下,因果掩码具有约 50% 的稀疏性,因此不利用这种稀疏性将导致速度慢 2 倍。尽管这个 score_mod
足以*正确*实现因果掩码,但要获得稀疏性的性能优势,还需要另一个概念——mask_mod
。
mask_mod
为了利用掩码带来的稀疏性,我们需要做更多的工作。具体来说,通过将 mask_mod
传递给 create_block_mask
,我们可以创建一个 BlockMask
。FlexAttention 随后就可以使用 BlockMask
来利用稀疏性!
注意,mask_mod
的签名与 score_mod
非常相似——只是没有 score
参数。具体来说
# returns True if this position should participate in the computation
mask_mod(b, h, q_idx, kv_idx) => bool
注意,score_mod
在表达能力上严格*强于* mask_mod
。然而,对于掩码操作,建议使用 mask_mod
和 create_block_mask
,因为这样性能更好。关于为何 score_mod
和 mask_mod
是分开的,请参阅 FAQ 部分。
现在,让我们看看如何使用 mask_mod
实现因果掩码。
因果掩码
from torch.nn.attention.flex_attention import create_block_mask
def causal(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
# Because the sparsity pattern is independent of batch and heads, we'll set them to None (which broadcasts them)
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
# In this case, we don't need a score_mod, so we won't pass any in.
# However, score_mod can still be combined with block_mask if you need the additional flexibility.
flex_attention(query, key, value, block_mask=block_mask)
请注意,create_block_mask
是一个相对昂贵的操作!尽管 FlexAttention 在它改变时不需要重新编译,但如果你不注意缓存它,它可能会导致显著的性能下降(请查看 FAQ 部分获取最佳实践建议)。
TFlops 大致相同,但使用 mask_mod 版本的执行时间快了 2 倍!这表明我们可以利用 BlockMask 提供的稀疏性,*而不会*损失硬件效率。
滑动窗口 + 因果注意力
来源:Mistral 7B
Mistral 推出的滑动窗口注意力(也称为局部注意力)利用了最近的 token 最有用的直觉。具体来说,它允许查询 token 只关注,例如,最近的 1024 个 token。这通常与因果注意力一起使用。
SLIDING_WINDOW = 1024
def sliding_window_causal(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
window_mask = q_idx - kv_idx <= SLIDING_WINDOW
return causal_mask & window_mask
# If you want to be cute...
from torch.nn.attention import and_masks
def sliding_window(b, h, q_idx, kv_idx)
return q_idx - kv_idx <= SLIDING_WINDOW
sliding_window_causal = and_masks(causal_mask, sliding_window)
我们将其与带有滑动窗口掩码的 F.scaled_dot_product_attention
以及带有因果掩码的 FA2(作为性能参考点)进行了基准测试。我们不仅比 F.scaled_dot_product_attention
快得多,而且比带有因果掩码的 FA2 *也*快得多,因为此掩码具有显著更高的稀疏性。
PrefixLM
来源:PaliGemma: A versatile 3B VLM for transfer
《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》中提出的 T5 架构描述了一种注意力变体,它在“前缀”上执行完全双向注意力,在其余部分执行因果注意力。我们再次组合两个掩码函数来实现这一点,一个用于因果掩码,另一个基于前缀长度。
prefix_length: [B]
def prefix_mask(b, h, q_idx, kv_idx):
return kv_idx <= prefix_length[b]
prefix_lm_causal = or_masks(prefix_mask, causal_mask)
# In this case, our mask is different per sequence so we set B equal to our batch size
block_mask = create_block_mask(prefix_lm_causal, B=B, H=None, S, S)
就像 score_mod
一样,mask_mod
允许我们引用未明确作为输入传入的额外张量!然而,对于 PrefixLM,稀疏性模式会随*每个* *输入*而变化。这意味着对于每个新的输入批次,我们将需要重新计算 BlockMask
。一种常见模式是在模型开始时调用 create_block_mask
并将该 block_mask
重用于模型中的所有注意力调用。请参阅 *重新计算 Block Mask 与重新编译* 部分。
然而,作为交换,我们不仅能够获得用于 PrefixLM 的高效注意力核函数,*还*能够利用输入中存在的任意稀疏性。FlexAttention 将根据 BlockMask 数据动态调整其性能,*无需*重新编译核函数。
文档掩码/不规则序列
另一种常见的注意力变体是文档掩码/不规则序列。想象你有多个长度不同的序列。你想把它们放在一起进行训练,但不幸的是,大多数算子只接受矩形张量。
通过 BlockMask
,我们也可以在 FlexAttention 中高效地支持这一点!
- 首先,我们将所有序列展平为一个包含 sum(序列长度) 个 token 的单个序列。
- 然后,我们计算每个 token 所属的 document_id。
- 最后,在我们的
mask_mod
中,我们只需判断查询 token 和 kv token 是否属于同一文档!
# The document that each token belongs to.
# e.g. [0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2] corresponds to sequence lengths 3, 2, and 6.
document_id: [SEQ_LEN]
def document_masking(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
就是这样!在这种情况下,我们看到最终得到了一个分块对角掩码(blockdiagonal mask)。
文档掩码的一个有趣之处在于,很容易看出它如何与任意组合的其他掩码结合。例如,我们已经在上一节中定义了 prefixlm_mask
。那么现在我们还需要定义一个 prefixlm_document_mask
函数吗?
在这些情况下,我们发现一种非常有用的模式,我们称之为“更高级别修改”。在这种情况下,我们可以获取一个现有的 mask_mod
,并自动将其转换为适用于不规则序列的 mask_mod
!
def generate_doc_mask_mod(mask_mod, document_id):
# Get unique document IDs and their counts
_, counts = torch.unique_consecutive(document_id, return_counts=True)
# Create cumulative counts (offsets)
offsets = torch.cat([torch.tensor([0], device=document_id.device), counts.cumsum(0)[:-1]])
def doc_mask_wrapper(b, h, q_idx, kv_idx):
same_doc = document_id[q_idx] == document_id[kv_idx]
q_logical = q_idx - offsets[document_id[q_idx]]
kv_logical = kv_idx - offsets[document_id[kv_idx]]
inner_mask = mask_mod(b, h, q_logical, kv_logical)
return same_doc & inner_mask
return doc_mask_wrapper
例如,给定上面提到的 prefix_lm_causal
掩码,我们可以将其转换为适用于打包文档的掩码,如下所示:
prefix_length = torch.tensor(2, dtype=torch.int32, device="cuda")
def prefix_mask(b, h, q_idx, kv_idx):
return kv_idx < prefix_length
prefix_lm_causal = or_masks(prefix_mask, causal_mask)
doc_prefix_lm_causal_mask = generate_doc_mask_mod(prefix_lm_causal, document_id)
现在,这个掩码形状是“分块-PrefixLM-对角”的。 :)
以上就是我们的所有示例!还有比我们列出的多得多的注意力变体,请查看 Attention Gym 以获取更多示例。我们希望社区也能贡献他们最喜欢的 FlexAttention 应用。
常见问题
问:FlexAttention 何时需要重新编译?
答:由于 FlexAttention 利用 torch.compile
进行图捕获,它实际上可以在许多情况下避免重新编译。值得注意的是,即使捕获的张量值发生变化,它也*不需要*重新编译!
flex_attention = torch.compile(flex_attention)
def create_bias_mod(bias)
def bias_mod(score, b, h, q_idx, kv_idx):
return score + bias
return bias_mod
bias_mod1 = create_bias_mod(torch.tensor(0))
flex_attention(..., score_mod=bias_mod1) # Compiles the kernel here
bias_mod2 = create_bias_mod(torch.tensor(2))
flex_attention(..., score_mod=bias_mod2) # Doesn't need to recompile!
即使更改分块稀疏性(block-sparsity)也不需要重新编译。但是,如果分块稀疏性发生变化,我们确实需要*重新计算* BlockMask。
问:我们何时应该重新计算 BlockMask?
答:当分块稀疏性发生变化时,我们需要重新计算 BlockMask。尽管计算 BlockMask 比重新编译便宜得多(大约是几百微秒而不是几秒),但你仍然应该注意不要过度频繁地重新计算 BlockMask。
以下是一些常见模式和建议。
掩码永不改变(例如:因果掩码)
在这种情况下,你可以简单地预计算并全局缓存 BlockMask,并在所有注意力调用中重用它。
block_mask = create_block_mask(causal_mask, 1, 1, S,S)
causal_attention = functools.partial(flex_attention, block_mask=block_mask)
掩码随每个批次变化(例如:文档掩码)
在这种情况下,我们建议在模型开始时计算 BlockMask 并将其贯穿整个模型——在所有层中重用该 BlockMask。
def forward(self, x, doc_mask):
# Compute block mask at beginning of forwards
block_mask = create_block_mask(doc_mask, None, None, S, S)
x = self.layer1(x, block_mask)
x = self.layer2(x, block_mask)
...
# amortize block mask construction cost across all layers
x = self.layer3(x, block_mask)
return x
掩码随每层变化(例如:数据依赖的稀疏性)
这是最难的情况,因为我们无法在多次调用 FlexAttention 之间分摊 Block Mask 的计算开销。尽管 FlexAttention 当然仍然可以使这种情况受益,但 BlockMask 的实际收益取决于你的注意力掩码有多稀疏以及我们构建 BlockMask 的速度有多快。这引出了下一个问题……
问:如何更快地计算 BlockMask?
答:不幸的是,create_block_mask
从内存和计算角度来看都相当昂贵,因为确定一个块是否完全稀疏需要评估该块中每个点的 mask_mod
。有几种方法可以解决这个问题。
- 如果你的掩码在批次大小或注意力头维度上是相同的,请确保对这些维度进行广播(即在
create_block_mask
中将它们设置为None
)。 - 编译
create_block_mask
。不幸的是,目前torch.compile
由于一些限制,尚无法直接作用于create_block_mask
。但是,你可以设置_compile=True
,这将显著降低峰值内存和运行时(通常是一个数量级)。 -
编写自定义的 BlockMask 构造函数。BlockMask 的元数据非常简单(请参阅文档)。它本质上是两个张量。
a.num_blocks
:为每个查询块计算的 KV 块数量。 b.indices
:为每个查询块计算的 KV 块的位置。例如,这里有一个针对
causal_mask
的自定义 BlockMask 构造函数。
def create_causal_mask(S):
BLOCK_SIZE = 128
# The first query block computes one block, the second query block computes 2 blocks, etc.
num_blocks = torch.arange(S // BLOCK_SIZE, device="cuda") + 1
# Since we're always computing from the left to the right,
# we can use the indices [0, 1, 2, ...] for every query block.
indices = torch.arange(S // BLOCK_SIZE, device="cuda").expand(
S // BLOCK_SIZE, S // BLOCK_SIZE
)
num_blocks = num_blocks[None, None, :]
indices = indices[None, None, :]
return BlockMask(num_blocks, indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=causal_mask)
问:为什么 score_mod
和 mask_mod
是分开的?mask_mod
不就是一个特殊的 score_mod
吗?
答:非常敏锐的问题,假想的听众朋友!事实上,任何 mask_mod
都可以轻松转换为 score_mod
(我们不建议在实践中使用此函数!)。
def mask_mod_as_score_mod(b, h, q_idx, kv_idx):
return torch.where(mask_mod(b, h, q_idx, kv_idx), score, -float("inf"))
那么,如果 score_mod
可以实现 mask_mod
可以实现的一切,那么存在 mask_mod
的意义何在?
一个直接的挑战是:score_mod
需要实际的 score
值作为输入,但在预计算 BlockMask 时,我们没有实际的 score
值。我们或许可以通过传入全零来伪造值,如果 score_mod
返回 -inf
,则认为它被掩码(事实上,我们最初就是这样做的!)。
然而,这里有两个问题。首先,这种做法很取巧(hacky)——如果用户的 score_mod
在输入为 0 时返回 -inf
怎么办?或者如果用户的 score_mod
使用一个很大的负值而不是 -inf
进行掩码怎么办?这似乎是试图将一个圆钉强行塞进一个方孔。然而,将 mask_mod
与 score_mod
分开还有一个更重要的原因——它从根本上来说更有效率!
事实证明,对计算出的每个元素应用掩码实际上相当昂贵——我们的基准测试显示性能下降约 15-20%!因此,尽管通过跳过一半计算可以获得显著加速,但由于需要对每个元素进行掩码,我们损失了加速效果中的很大一部分!
幸运的是,如果我们可视化因果掩码,我们会注意到绝大多数块根本不需要“因果掩码”——它们是完全计算的!只有对角线上的块,部分计算且部分被掩码,才需要应用掩码。
BlockMask 之前告诉我们哪些块需要计算,哪些块可以跳过。现在,我们进一步增强此数据结构,以告诉我们哪些块是“完全计算的”(即可以跳过掩码)与“部分计算的”(即需要应用掩码)。然而,请注意,尽管在“完全计算的”块上可以跳过掩码,但其他 score_mod
,如相对位置编码,仍然需要应用。
仅给定一个 score_mod
,我们无法准确判断其中哪些部分是“掩码”。因此,用户必须自行将其拆分为 mask_mod
。
问:BlockMask 需要多少额外内存?
答:BlockMask 元数据的大小为 [BATCH_SIZE, NUM_HEADS, QUERY_LEN//BLOCK_SIZE, KV_LEN//BLOCK_SIZE]。
如果掩码在批次大小或注意力头维度上是相同的,可以在该维度上进行广播以节省内存。
在默认 BLOCK_SIZE
为 128 时,对于大多数用例,我们预计内存使用量可以忽略不计。例如,对于序列长度为 100 万的情况,BlockMask 只会使用 60MB 的额外内存。如果这是一个问题,你可以增加块大小:create_block_mask(..., BLOCK_SIZE=1024)。
例如,将 BLOCK_SIZE
增加到 1024 将使此元数据降至不到 1MB。
问:数值计算结果如何比较?
答:尽管结果并非完全按位相同,但我们确信 FlexAttention 与 FlashAttention 一样具有数值精度。我们在各种输入(包括因果和非因果注意力变体)上生成了 FlashAttention 与 FlexAttention 比较的差异分布。误差几乎相同。
性能
一般来说,FlexAttention 的性能几乎与手写的 Triton 核函数一样,这并不奇怪,因为我们大量利用了手写的 Triton 核函数。然而,由于其通用性,我们确实会产生一小部分性能开销。例如,我们必须承担一些额外的延迟来确定下一个计算哪个块。在某些情况下,我们提供了一些核函数选项,这些选项可以在改变其行为的同时影响核函数的性能。这些选项可以在此处找到:性能旋钮(performance knobs)。
作为案例研究,让我们探讨一下这些旋钮如何影响因果注意力的性能。我们将比较 Triton 核函数与 FlashAttentionv2 在 A100 上的性能。脚本可以在此处找到。
FlexAttention 在前向传播中达到 FlashAttention2 性能的 90%,在反向传播中达到 85%。FlexAttention 目前使用确定性算法,重新计算的中间结果比 FAv2 多,但我们计划改进 FlexAttention 的反向算法,并希望缩小这一差距!
结论
希望你在使用 FlexAttention 时也能像我们开发它时一样充满乐趣!在开发过程中,我们发现了比预期多得多的 API 应用。我们已经看到它将 torchtune 的样本打包吞吐量提高了 71%,使研究人员无需花费一周多时间编写自己的自定义 Triton 核函数,并提供了与自定义手写注意力变体相当的性能。
此外,实现 FlexAttention 的另一个乐趣在于,我们能够以有趣的方式利用许多现有的 PyTorch 基础设施。例如,TorchDynamo (torch.compile
的前端) 的一个独特之处在于,它*不*要求编译函数中使用的张量必须明确作为输入传入。这使得我们可以编译文档掩码等模块,而这些模块需要访问*全局*变量,并且全局变量需要改变!
bias = torch.randn(1024, 1024)
def score_mod(score, b, h, q_idx, kv_idx):
return score + bias[q_idx][kv_idx] # The bias tensor can change!
此外,torch.compile
是一个通用的图捕获机制,这一事实也使其能够支持更“高级”的转换,例如将任意 mask_mod
转换为适用于不规则张量的高阶转换。
我们还利用了 TorchInductor (torch.compile
的后端) 的基础设施,特别是 Triton 模板。这不仅使代码生成 FlexAttention 变得容易,还自动为我们提供了动态形状以及后处理融合(epilogue fusion,即将一个算子融合到注意力计算的末尾)的支持!未来,我们计划扩展此支持,以涵盖注意力量化版本或类似 RadixAttention 的功能。
此外,我们还利用了高阶算子(higher order ops)、PyTorch 的 autograd 来自动生成反向传播,以及 vmap 来自动应用 score_mod
以创建 BlockMask。
当然,如果没有 Triton 和 TorchInductor 生成 Triton 代码的能力,这个项目是不可能实现的。
我们期待未来将这里使用的方法应用于更多场景!
限制与未来工作
- FlexAttention 目前可在 PyTorch nightly 版本中使用,我们计划在 2.5.0 版本中将其作为原型功能发布。
- 我们在此未涉及如何在推理时使用 FlexAttention(或如何实现 PagedAttention)——我们将在后续文章中介绍这些内容。
- 我们正在努力提高 FlexAttention 的性能,使其在 H100 GPU 上与 FlashAttention3 相当。
- FlexAttention 要求所有序列长度必须是 128 的倍数——这将在近期解决。
- 我们计划尽快添加 GQA 支持——目前,你可以直接复制 KV 头。
致谢
我们想重点介绍一些启发 FlexAttention 的先前工作(和人物)。
- Tri Dao 在 FlashAttention 方面的工作
- Francisco Massa 和 Xformers 团队在 Triton 中的 BlockSparseAttention 工作
- Jax 团队在 SplashAttention 方面的工作
- Philippe Tillet 和 Keren Zhou 在 Triton 方面给予的帮助
- Ali Hassani 关于邻域注意力(neighborhood attention)的讨论
- 那些抱怨注意力核函数不支持他们喜欢的注意力变体的人们 :)