博客

FlexAttention:PyTorch 的灵活性与 FlashAttention 的性能

a cartoon chart flexing his muscles

理论上,“Attention is All You Need”。然而在实践中,我们还需要像 FlashAttention 这样经过优化的注意力机制实现。

尽管这些融合的注意力实现极大地提升了性能并支持长上下文,但这种效率是以牺牲灵活性为代价的。你无法再通过编写几个 PyTorch 算子来尝试新的注意力变体——通常你需要编写一个新的自定义内核(Kernel)!这对于机器学习研究人员来说,相当于一场“软件彩票”——如果你的注意力变体无法适配现有的优化内核,那么你注定要面临缓慢的运行时间和 CUDA 显存溢出(OOM)。

一些注意力变体的例子包括:因果注意力(Causal)、相对位置编码Alibi滑动窗口注意力PrefixLM文档掩码/样本打包/锯齿状张量(Jagged Tensors)Tanh 软截断(Soft-Capping)PagedAttention 等。更糟糕的是,人们往往需要组合这些变体!滑动窗口注意力 + 文档掩码 + 因果注意力 + 上下文并行?或者 PagedAttention + 滑动窗口 + Tanh 软截断?

下图左侧展示了当下的状况——掩码、偏置和各种设置的某些组合已经有了现成的内核实现。但大量的选项导致了呈指数级增长的设置组合,因此最终我们得到的支持非常零散。更糟糕的是,研究人员提出的新注意力变体将获得“零”支持。

Attention variant support diagram

为了彻底解决这个超立方体问题,我们推出了一个新的 PyTorch API:FlexAttention

  1. 我们提供了一个灵活的 API,允许用寥寥几行地道的 PyTorch 代码实现多种注意力变体(包括目前博文中提到的所有变体)。
  2. 我们通过 torch.compile 将其降级(lower)为融合的 FlashAttention 内核,生成的 FlashAttention 内核不会占用任何额外的内存,且性能与手写内核旗鼓相当。
  3. 我们还利用 PyTorch 的自动微分(autograd)机制,自动生成了反向传播过程。
  4. 最后,我们还可以利用注意力掩码中的稀疏性,从而相比标准注意力实现获得显著的性能提升。

有了 FlexAttention,我们希望尝试新的注意力变体时,唯一的限制就是你的想象力。

你可以在 Attention Gym 中找到许多 FlexAttention 的示例:https://github.com/pytorch-labs/attention-gym。如果你有任何酷炫的应用,欢迎提交示例!

附:我们也觉得这个 API 非常令人兴奋,因为它以一种有趣的方式利用了大量的现有 PyTorch 基础设施——文末会有更多相关内容。

FlexAttention

这是经典的注意力公式:

代码形式:

Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
probabilities = softmax(score, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V

FlexAttention 允许用户自定义一个函数 score_mod

math equation

代码形式:

Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
modified_scores: Tensor[batch_size, num_heads, sequence_length, sequence_length] = score_mod(score)
probabilities = softmax(modified_scores, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V

该函数允许你在 softmax 之前“修改”注意力得分。令人惊讶的是,这对于绝大多数注意力变体来说已经足够了(示例如下)!

具体来说,score_mod 的预期签名是相当独特的:

def score_mod(score: f32[], b: i32[], h: i32[], q_idx: i32[], kv_idx: i32[])
    return score # noop - standard attention

换句话说,score 是一个标量 PyTorch 张量,代表查询(query)标记和键(key)标记的点积。其余参数用于标识你当前正在计算的点积属于“哪一个”——b(批次中的当前元素)、h(当前头)、q_idx(查询中的位置)、kv_idx(键/值张量中的位置)。

要应用此函数,我们可以这样实现:

for b in range(batch_size):
    for h in range(num_heads):
        for q_idx in range(sequence_length):
            for kv_idx in range(sequence_length):
                modified_scores[b, h, q_idx, kv_idx] = score_mod(scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)

当然,这不是 FlexAttention 底层的实现方式。利用 torch.compile,我们自动将你的函数降级为一个单一的融合 FlexAttention 内核——无效退款保证!

该 API 的表达能力出人意料地强大。让我们看看一些示例。

Score Mod 示例

全注意力(Full Attention)

我们先来看看“全注意力”,即标准的双向注意力。在这种情况下,score_mod 是一个空操作(no-op)——它接收输入分数并原样返回。

def noop(score, b, h, q_idx, kv_idx):
    return score

要端到端地使用它(包括前向和反向传播):

from torch.nn.attention.flex_attention import flex_attention

flex_attention(query, key, value, score_mod=noop).sum().backward()

相对位置编码(Relative Position Encodings)

一种常见的注意力变体是“相对位置编码”。它不是在查询和键中编码绝对距离,而是根据查询和键之间的“距离”来调整分数。

def relative_positional(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

注意,与典型实现不同,这不需要显式创建一个 SxS 的张量。相反,FlexAttention 在内核内“即时(on the fly)”计算偏置值,从而带来显著的内存和性能提升。

ALiBi 偏置

alibi bias

来源:Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

ALiBi 引入于《Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation》,声称在推理的长度外推方面具有有益特性。值得注意的是,MosaicML 指出“缺乏内核支持”是他们最终从 ALiBi 转向旋转位置编码(rotary embeddings)的主要原因。

Alibi 与相对位置编码类似,只有一个区别——它有一个通常预先计算好的逐头因子(per-head factor)。

alibi_bias = generate_alibi_bias() # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias

这展示了 torch.compile 提供的灵活性——即使 alibi_bias 没有作为输入显式传递,我们也可以从中加载数据!生成的 Triton 内核将计算从 alibi_bias 张量的正确加载并进行融合。请注意,即使你重新生成 alibi_bias,我们也无需重新编译。

软截断(Soft-capping)

软截断是一种用于 Gemma2 和 Grok-1 的技术,用于防止 Logits 过大。在 FlexAttention 中,它看起来像这样:

softcap = 20
def soft_cap(score, b, h, q_idx, kv_idx):
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score

注意,我们也在这里从前向传播自动生成了反向传播。此外,虽然此实现语义正确,但出于性能考虑,我们可能需要在此处使用 tanh 近似。更多详情请参见 attention-gym

因果掩码(Causal Mask)

虽然双向注意力是最简单的,但最初的《Attention is All You Need》论文和绝大多数 LLM 在仅解码器(decoder-only)设置中使用注意力,其中每个标记只能关注其之前的标记。人们通常将其视为下三角掩码,但在 score_mod API 下,它可以表示为:

def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

基本逻辑是:如果查询标记在键标记“之后”,则保留分数。否则,通过将其设置为 -inf 来将其掩盖,从而确保它不会参与 softmax 计算。

然而,掩码与其他修改相比是特殊的——如果某些内容被屏蔽,我们可以完全跳过其计算!在这种情况下,因果掩码具有约 50% 的稀疏性,因此如果不利用这种稀疏性将导致 2 倍的减速。虽然 score_mod 足以“正确”地实现因果掩码,但要获得稀疏性的性能优势,还需要另一个概念——mask_mod

Mask Mods

为了利用掩码带来的稀疏性,我们需要做更多工作。具体而言,通过将 mask_mod 传递给 create_block_mask,我们可以创建一个 BlockMask。FlexAttention 随后可以使用 BlockMask 来利用这种稀疏性!

mask_mod 的签名与 score_mod 非常相似,只是没有 score 参数。具体如下:

# returns True if this position should participate in the computation
mask_mod(b, h, q_idx, kv_idx) => bool

请注意,score_mod 的表达能力严格强于 mask_mod。然而,对于掩码操作,建议使用 mask_modcreate_block_mask,因为它的性能更高。关于为什么 score_modmask_mod 是分开的,请参阅 FAQ。

现在,让我们看看如何使用 mask_mod 实现因果掩码:

因果掩码(Causal Mask)

from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Because the sparsity pattern is independent of batch and heads, we'll set them to None (which broadcasts them) 
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
# In this case, we don't need a score_mod, so we won't pass any in.
# However, score_mod can still be combined with block_mask if you need the additional flexibility.
flex_attention(query, key, value, block_mask=block_mask)

请注意,create_block_mask 是一个相对昂贵的操作!虽然 FlexAttention 在其变化时无需重新编译,但如果你不注意缓存,它可能会导致显著的减速(查看 FAQ 以获取最佳实践建议)。

虽然 TFlops 大致相同,但 mask_mod 版本的执行速度快了 2 倍!这证明了我们可以在损失硬件效率的情况下,利用 BlockMask 提供的稀疏性。

滑动窗口 + 因果(Sliding Window + Causal)

Sliding Window Causal diagrams

来源:Mistral 7B

Mistral 推崇,滑动窗口注意力(也称为局部注意力)利用了“最近的标记最有用”的直觉。特别是,它允许查询标记仅关注最近的(例如)1024 个标记。这通常与因果注意力结合使用。

SLIDING_WINDOW = 1024

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask

# If you want to be cute...
from torch.nn.attention import and_masks

def sliding_window(b, h, q_idx, kv_idx)
    return q_idx - kv_idx <= SLIDING_WINDOW

sliding_window_causal = and_masks(causal_mask, sliding_window)

我们将它与使用滑动窗口掩码的 F.scaled_dot_product_attention 以及带有因果掩码的 FA2(作为性能参考)进行了基准测试。我们不仅比 F.scaled_dot_product_attention 快得多,而且比带有因果掩码的 FA2 也要快得多,因为该掩码具有显著更高的稀疏性。

execution time charts

PrefixLM

PrefixLM diagram

来源:PaliGemma: A versatile 3B VLM for transfer

《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》提出的 T5 架构描述了一种注意力变体,它在“前缀(prefix)”上执行全双向注意力,而在其余部分执行因果注意力。我们再次组合两个掩码函数来实现这一点,一个用于因果掩码,另一个基于前缀长度。

prefix_length: [B]
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx <= prefix_length[b]

prefix_lm_causal = or_masks(prefix_mask, causal_mask)
# In this case, our mask is different per sequence so we set B equal to our batch size
block_mask = create_block_mask(prefix_lm_causal, B=B, H=None, S, S)

就像 score_mod 一样,mask_mod 允许我们引用非函数显式输入的额外张量!然而,对于 prefixLM,稀疏模式随输入而变化。这意味着对于每个新的输入批次,我们都需要重新计算 BlockMask。一种常见的模式是在模型开头调用 create_block_mask,并在模型的所有注意力调用中重用该 block_mask。请参阅“重新计算块掩码 vs. 重新编译”。

然而,作为交换,我们不仅能够为 prefixLM 拥有一个高效的注意力内核,还能够利用输入中存在的任何稀疏性!FlexAttention 会根据 BlockMask 数据动态调整其性能,无需重新编译内核。

文档掩码/锯齿状序列(Document Masking/Jagged Sequences)

另一种常见的注意力变体是文档掩码/锯齿状序列。想象你有多个长度不一的序列。你希望将它们放在一起训练,但不幸的是,大多数算子只接受矩形张量。

通过 BlockMask,我们也可以在 FlexAttention 中高效地支持这一点!

  1. 首先,我们将所有序列平铺成一个包含 sum(sequence lengths) 个标记的单一序列。
  2. 然后,我们计算每个标记所属的 document_id。
  3. 最后,在 mask_mod 中,我们只需判断查询和 KV 标记是否属于同一个文档即可!
# The document that each token belongs to.
# e.g. [0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2] corresponds to sequence lengths 3, 2, and 6.
document_id: [SEQ_LEN]

def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]

就这样!在这种情况下,我们最终得到的是一个块对角线(block-diagonal)掩码。

blockdiagonal mask

关于文档掩码,一个有趣的方面是,它很容易与任意其他掩码组合。例如,我们已经在上一节中定义了 prefixlm_mask。我们现在还需要定义一个 prefixlm_document_mask 函数吗?

在这些情况下,我们发现一种非常实用的模式,称为“高阶修改(higher level modification)”。在这种情况下,我们可以采用现有的 mask_mod 并自动将其转换为适用于锯齿状序列的掩码!

def generate_doc_mask_mod(mask_mod, document_id):
    # Get unique document IDs and their counts
    _, counts = torch.unique_consecutive(document_id, return_counts=True)
    # Create cumulative counts (offsets)
    offsets = torch.cat([torch.tensor([0], device=document_id.device), counts.cumsum(0)[:-1]])
    def doc_mask_wrapper(b, h, q_idx, kv_idx):
        same_doc = document_id[q_idx] == document_id[kv_idx]
        q_logical = q_idx - offsets[document_id[q_idx]]
        kv_logical = kv_idx - offsets[document_id[kv_idx]]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask
    return doc_mask_wrapper

例如,给定上述 prefix_lm_causal 掩码,我们可以将其转换为适用于打包文档的掩码:

prefix_length = torch.tensor(2, dtype=torch.int32, device="cuda")
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx < prefix_length
prefix_lm_causal = or_masks(prefix_mask, causal_mask)
doc_prefix_lm_causal_mask = generate_doc_mask_mod(prefix_lm_causal, document_id)

现在,这个掩码的形状就是“块前缀LM对角线”形的。🙂

这就是我们的所有示例!注意力变体的数量远超我们的列举范围,所以请查阅 Attention Gym 获取更多示例。我们希望社区也能贡献一些他们最喜欢的 FlexAttention 应用。

常见问题解答 (FAQ)

问:FlexAttention 何时需要重新编译?

由于 FlexAttention 利用 torch.compile 进行图捕获,它实际上可以在广泛的情况下避免重新编译。值得注意的是,即使捕获的张量改变了数值,它也不需要重新编译!

flex_attention = torch.compile(flex_attention)
def create_bias_mod(bias)
    def bias_mod(score, b, h, q_idx, kv_idx):
        return score + bias
    return bias_mod
bias_mod1 = create_bias_mod(torch.tensor(0))
flex_attention(..., score_mod=bias_mod1) # Compiles the kernel here 

bias_mod2 = create_bias_mod(torch.tensor(2))
flex_attention(..., score_mod=bias_mod2) # Doesn't need to recompile! 

即使更改块稀疏性也不需要重新编译。但是,如果块稀疏性发生了变化,我们确实需要重新计算 BlockMask。

问:我们应该何时重新计算 BlockMask?

每当块稀疏性发生变化时,我们都需要重新计算 BlockMask。虽然计算 BlockMask 比重新编译便宜得多(耗时在数百微秒级别,而不是秒级),但你仍然应该注意不要过度频繁地重新计算。

以下是一些常见模式以及我们关于如何处理它们的建议:

掩码从不改变(例如因果掩码)
在这种情况下,你可以简单地预计算块掩码并将其全局缓存,在所有注意力调用中重复使用。

block_mask = create_block_mask(causal_mask, 1, 1, S,S)
causal_attention = functools.partial(flex_attention, block_mask=block_mask)

掩码每个批次都会改变(例如文档掩码)
在这种情况下,我们建议在模型开头计算 BlockMask 并将其贯穿传递到模型中——在所有层中重用 BlockMask。

def forward(self, x, doc_mask):
    # Compute block mask at beginning of forwards
    block_mask = create_block_mask(doc_mask, None, None, S, S)    
    x = self.layer1(x, block_mask)
    x = self.layer2(x, block_mask)
    ...
    # amortize block mask construction cost across all layers
    x = self.layer3(x, block_mask) 
    return x

掩码每一层都会改变(例如数据依赖的稀疏性)
这是最困难的情况,因为我们无法在多次 FlexAttention 调用中摊销块掩码的计算。虽然 FlexAttention 在这种情况下肯定仍然有效,但 BlockMask 的实际收益取决于你的注意力掩码有多稀疏以及我们构造 BlockMask 的速度。这引出了下一个问题……

问:我们如何更快地计算 BlockMask?

create_block_mask 不幸的是相当昂贵,无论是从内存还是计算的角度来看,因为确定一个块是否完全稀疏需要在块中的每一个点上评估 mask_mod。有几种方法可以解决这个问题:

  1. 如果你的掩码在批次大小或头数上是相同的,请确保你正在进行广播(即在 create_block_mask 中将它们设置为 None)。
  2. 编译 create_block_mask。不幸的是,目前由于一些限制,torch.compile 不能直接在 create_block_mask 上工作。但是,你可以设置 _compile=True,这将在我们的测试中显著降低峰值内存和运行时间(通常是一个数量级)。
  3. 为 BlockMask 编写自定义构造函数。BlockMask 的元数据非常简单(请参阅文档)。它本质上是两个张量:a. num_blocks:为每个查询块计算的 KV 块数量。
    b. indices:为每个查询块计算的 KV 块的位置。例如,这是 causal_mask 的自定义 BlockMask 构造函数。
def create_causal_mask(S):
    BLOCK_SIZE = 128
    # The first query block computes one block, the second query block computes 2 blocks, etc.
    num_blocks = torch.arange(S // BLOCK_SIZE, device="cuda") + 1
    # Since we're always computing from the left to the right,
    # we can use the indices [0, 1, 2, ...] for every query block.
    indices = torch.arange(S // BLOCK_SIZE, device="cuda").expand(
        S // BLOCK_SIZE, S // BLOCK_SIZE
    )
    num_blocks = num_blocks[None, None, :]
    indices = indices[None, None, :]
    return BlockMask(num_blocks, indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=causal_mask)
问:为什么 score_modmask_mod 不同?mask_mod 不仅仅是 score_mod 的一个特例吗?

非常敏锐的问题!事实上,任何 mask_mod 都可以轻松转换为 score_mod(但在实践中我们不建议使用此函数!)。

def mask_mod_as_score_mod(b, h, q_idx, kv_idx):
    return torch.where(mask_mod(b, h, q_idx, kv_idx), score, -float("inf"))

那么,如果 score_mod 可以实现 mask_mod 的所有功能,为什么要分开设置 mask_mod 呢?

一个直接的挑战是:score_mod 需要实际的 score 值作为输入,但在我们预计算 BlockMask 时,我们并没有实际的 score 值。我们或许可以通过传入全零来模拟这些值,如果 score_mod 返回 -inf,我们就认为它被屏蔽了(事实上,我们最初就是这么做的!)。

然而,有两个问题。首先,这很 hacky——如果用户的 score_mod 在输入为 0 时返回了 -inf 怎么办?或者如果用户的 score_mod 使用大的负值而不是 -inf 进行屏蔽怎么办?这看起来是在强行塞入不匹配的组件。但是,将 mask_modscore_mod 分开还有一个更重要的原因——它从根本上更有效率!

事实证明,对计算出的每个元素应用掩码实际上非常昂贵——我们的基准测试显示性能下降了约 15-20%!所以,虽然我们可以通过跳过一半的计算获得显著的速度提升,但我们需要通过掩盖每个元素来抵消其中的一部分提升!

幸运的是,如果我们可视化因果掩码,我们会注意到绝大多数块根本不需要“因果掩码”——它们是完全计算的!只有对角线上的块(部分计算、部分掩盖)才需要应用掩码。

blockdiagonal mask

BlockMask 先前告诉我们哪些块需要计算,哪些块可以跳过。现在,我们进一步扩充了该数据结构,以告知哪些块是“完全计算的”(即可以跳过掩码)与“部分计算的”(即需要应用掩码)。但请注意,即使在“完全计算的”块上可以跳过掩码,其他 score_mod(如相对位置嵌入)仍然需要应用。

仅给定一个 score_mod,我们无法合理地判断哪些部分属于“掩码”。因此,用户必须将它们手动拆分为 mask_mod

问:BlockMask 需要多少额外内存?

BlockMask 元数据的大小为 [BATCH_SIZE, NUM_HEADS, QUERY_LEN//BLOCK_SIZE, KV_LEN//BLOCK_SIZE]。如果掩码在批次或头维度上相同,则可以在该维度上进行广播以节省内存。

在默认的 BLOCK_SIZE 为 128 时,我们预计对于大多数用例,内存使用量几乎可以忽略不计。例如,对于 100 万的序列长度,BlockMask 只会使用额外的 60MB 内存。如果这是一个问题,你可以增加块大小:create_block_mask(..., BLOCK_SIZE=1024)。例如,将 BLOCK_SIZE 增加到 1024 会将该元数据降至 1MB 以下。

问:数值精度如何比较?

虽然结果并非逐位完全相同,但我们确信 FlexAttention 在数值上与 FlashAttention 一样准确。我们在各种因果和非因果注意力变体的大量输入上比较了 FlashAttention 与 FlexAttention,并生成了以下差异分布。误差几乎完全一致。

distribution chart

性能

总的来说,FlexAttention 的性能几乎与手写的 Triton 内核相当,这并不奇怪,因为我们大量利用了手写的 Triton 内核。然而,由于其通用性,我们确实承担了一些性能损失。例如,我们必须承担一些额外的延迟来确定下一个要计算的块。在某些情况下,我们提供了一些内核选项,这些选项可以在改变行为的同时影响内核的性能。可以在这里找到它们:performance knobs

作为案例研究,让我们探讨这些旋钮如何影响因果注意力的性能。我们将比较在 A100 上 Triton 内核与 FlashAttentionv2 的性能。脚本可以在这里找到。

FlexAttention 在前向传播中达到了 FlashAttention2 性能的 90%,在反向传播中达到了 85%。FlexAttention 目前使用的是一种确定性算法,比 FAv2 重计算了更多的中间变量,但我们计划改进 FlexAttention 的反向算法,并希望缩小这一差距!

flexattention speed chart
flexattention speed chart

结论

我们希望你使用 FlexAttention 的过程和我们开发它时一样有趣!在开发过程中,我们发现该 API 的应用场景远超预期。我们已经看到它将 torchtune 的样本打包吞吐量提升了 71%,无需研究人员花费一周多时间编写自定义 Triton 内核,并提供了与自定义手写注意力变体相当的性能。

实现 FlexAttention 的另一个有趣之处在于,我们能够以有趣的方式利用大量的现有 PyTorch 基础设施。例如,TorchDynamo(torch.compile 的前端)的独特之处在于,它不需要显式地将编译函数中使用的张量作为输入传入。这使我们能够编译像文档掩码这样的模组,这些模组需要访问可能发生变化的全局变量!

bias = torch.randn(1024, 1024)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx][kv_idx] # The bias tensor can change!

此外,torch.compile 作为一种通用的图捕获机制,也使它能够支持更“高级”的转换,例如将任何 mask_mod 转换为适用于锯齿状张量的高阶转换。

我们还利用了 TorchInductor(torch.compile 的后端)的 Triton 模板基础设施。这不仅使支持 FlexAttention 的代码生成变得容易,还自动为我们提供了对动态形状以及尾部融合(epilogue fusion,即在注意力末尾融合算子)的支持!未来,我们计划扩展此支持,以允许量化版本的注意力或类似 RadixAttention 的功能。

此外,我们还利用了高阶算子、PyTorch 的自动微分来自动生成反向传播,以及 vmap 来自动应用 score_mod 以创建 BlockMask。

当然,如果没有 Triton 和 TorchInductor 生成 Triton 代码的能力,这个项目是不可能实现的。

我们期待在未来将我们在这里使用的方法应用到更多场景中!

限制与未来工作

  • FlexAttention 目前可在 PyTorch 夜间版本(nightly releases)中使用,我们计划在 2.5.0 版本中将其作为原型功能发布。
  • 我们在此未涵盖如何使用 FlexAttention 进行推理(或如何实现 PagedAttention)——我们将在以后的文章中讨论这些内容。
  • 我们正在努力提升 FlexAttention 的性能,以匹配 H100 GPU 上的 FlashAttention3。
  • FlexAttention 要求所有序列长度都是 128 的倍数——这个问题很快会得到解决。
  • 我们计划很快添加 GQA 支持——目前,你可以简单地复制 KV 头。

致谢

我们想要强调一些启发了 FlexAttention 的过往工作(以及相关人员):

  • Tri Dao 在 FlashAttention 上的工作
  • Francisco Massa 和 Xformers 团队在 Triton 中实现的 BlockSparseAttention
  • Jax 团队在 SplashAttention 上的工作
  • Philippe Tillet 和 Keren Zhou 对我们在 Triton 方面提供的帮助
  • Ali Hassani 关于邻域注意力(neighborhood attention)的讨论
  • 所有抱怨注意力内核不支持他们最喜欢的注意力变体的人 🙂