博客

FlexAttention 第二部分:用于推理的 FlexAttention

概述

在 PyTorch 2.5.0 版本中,我们为那些希望自定义注意力算子(kernel)却不想编写内核代码的机器学习研究人员推出了 FlexAttention (torch.nn.attention.flex_attention)。本篇博客介绍了我们针对推理优化的解码后端,支持 GQA 和 PagedAttention,并更新了包括性能调优指南和可训练偏置支持在内的多项功能。

如果您正在寻找一种在训练后/推理管道中轻松试用 FlexAttention 的方法,PyTorch 原生训练后库 torchtune 和推理代码库 gpt-fast 已经集成了 FlexAttention。欢迎试用!

我们很高兴地宣布,关于 FlexAttention 的论文已被定于 5 月 12 日至 15 日在加利福尼亚州圣克拉拉举行的 MLSys2025 会议录用并受邀进行演示。

标题:FlexAttention:一种生成优化注意力算子的编程模型。 海报

用于推理的 FlexAttention

摘要:当在极短的查询(query)上运行时,torch.compile 会将 flex_attention 降级为融合的 FlashDecoding 算子。

单一的融合注意力算子并不适用所有场景——尤其是在长上下文 LLM 推理中。

LLM 推理的解码阶段是一个迭代过程:每次生成一个 token,生成一个 N 个 token 的句子需要进行 N 次前向传播。幸运的是,每次迭代不需要对整个句子重新计算自注意力——先前计算出的 token 已被缓存,因此我们只需将新生成的 token 与缓存的上下文进行注意力计算。

这导致了一种独特的注意力模式:短查询序列(1 个 token)关注长键值(KV)缓存(上下文长度可达 128k)。针对正方形注意力算子(q_len ≈ kv_len)的传统优化在这里并不直接适用。这种模式对 GPU 内存利用率和占用率提出了新的挑战。我们构建了一个专门的 FlexDecoding 后端,通过引入 FlashDecoding 中的解码专用技术,针对长上下文 LLM 推理进行了优化。

FlexDecoding 是作为 torch.nn.attention.flex_attention 算子的替代后端实现的。当输入为短查询和长 KV 缓存时,flex_attention 会在即时编译(JIT)时自动切换到 FlexDecoding 后端。如果输入形状发生显著变化(例如从预填充阶段过渡到解码阶段),JIT 重新编译会为每个场景生成单独的算子。

flex_attention = torch.compile(flex_attention)

k_cache = torch.random(B, H, 16384, D) 
v_cache = torch.random(B, H, 16384, D)

...

# Prefill Phase: query shape = [B, H, 8000, D]
flex_attention(q_prefill, k_cache, v_cache, ...) # Uses FlexAttention backend optimized for prefill & training

# Decoding Phase: q_last_token shape = [B, H, 1, D]
flex_attention(q_last_token  , k_cache, v_cache, ...) # Recompiles with the FlexDecoding backend 

# decode 2 tokens at the same time: q_last_2_tokens shape = [B, H, 2, D]
flex_attention(q_last_2_tokens, k_cache, v_cache, ...) # No recompilation needed! Runs the decoding kernel again.

使用 KV 缓存

高效推理的关键优化之一是维护一个预分配的 KV 缓存,并在生成新 token 时原地(in place)更新。FlexDecoding 不强制使用带有特定 API 的 KV 缓存策略,而是允许用户自行定义和管理 KV 缓存。

与 FlexAttention 类似,FlexDecoding 接受用户定义的 mask_modscore_mod 函数。这些函数会在 softmax 操作之前修改注意力分数。

score_mod(score, b, h, q_idx, kv_idx) -> tensor # return updated score

分数(Score)是一个标量 PyTorch 张量,代表查询 token 和键 token 的点积。其余参数指定了正在计算的分数:

  • b:batch 索引
  • h:注意力头索引
  • q_idx:查询张量中的 token 位置
  • kv_idx:键/值张量中的 token 位置

在解码阶段,先前计算的 token 会被缓存,仅使用最新生成的 token(第 i 个)作为查询。对此单 token 查询的朴素因果掩码(causal mask)如下所示:

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

这存在问题:新 token“saw”应该关注所有先前生成的 token,即“The cat sat on the mat and saw”,而不仅仅是 KV 缓存中的第一个条目。为了纠正这一点,score_mod 需要通过 i 来偏移 q_idx,从而实现精确解码。

为每个 token 创建一个新的 score_mod 以适应偏移量速度很慢,因为这意味着 FlexAttention 必须在每次迭代中针对不同的 score_mod 重新编译。相反,

我们将此 offset 定义为一个张量,并在每次迭代时增加其值。

offset = torch.tensor(i, "cuda")
def causal_w_offset(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx + offset >= kv_idx, score, -float("inf"))

# Attend the i-th token
flex_attention(..., score_mod=causal_w_offset  ) # Compiles the kernel here 
...
# Attend the i+1-th token
offset = offset + 1 # Increment offset
flex_attention(..., score_mod=causal_w_offset ) # Doesn't need to recompile! 

值得注意的是,这里的 offset 成为了一个捕获的张量,如果 offset 的值发生变化,它不需要重新编译。

无需手动重写用于处理偏移量的 score_modmask_mod。我们可以通过一个通用的重写器来自动化此过程。

offset = torch.tensor(i, "cuda")

def get_score_mod_w_offset(score_mod: _score_mod_signature, _offset: tensor):
    def _score_mod(score, b, h, q, kv):
        return score_mod(score, b, h, q + _offset, kv)
    return _score_mod

def get_mask_mod_w_offset(mask_mod: _mask_mod_signature, _offset: tensor):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + _offset, kv)
    return _mask_mod

causal_w_offset = get_score_mod_w_offset(causal, offset)

用于推理的 BlockMask

我们还可以在推理中使用 BlockMask 来利用掩码稀疏性。其思路是在模型设置期间预先计算一次 BlockMask,并在解码过程中使用它的切片。

预计算 BlockMask

在设置期间,我们为 MAX_SEQ_LEN x MAX_SEQ_LEN 创建一个正方形 BlockMask。

from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=MAX_SEQ_LEN,KV_LEN=MAX_SEQ_LEN)

在解码过程中使用 BlockMask

对于第 i 个 token,我们使用掩码的一个切片。

block_offset = i // block_mask.BLOCK_SIZE[0]
block_mask_slice = block_mask[:, :, block_offset]

# don't forget to use the mask_mod with offset! 
block_mask_slice.mask_mod = get_mask_mod_w_offset(causal_mask)

性能

FlexDecoding 算子的性能与 FlashDecoding (FAKV) 持平,并显著优于 PyTorch 的 scaled_dot_product_attention (代码)。

与 gpt-fast 中的 SDPA 相比,FlexDecoding 将 LLaMa3.1-8B 的服务性能提升了 1.22 倍至 2.04 倍,将 LLaMa3.1-70B 的性能提升了 0.99 倍至 1.66 倍。(代码)

分页注意力 (Paged Attention)

vLLM 是流行的 LLM 服务引擎之一,得益于 PagedAttention 带来的高效内存管理。现有的 PagedAttention 实现需要专用的 CUDA 内核,在支持新兴注意力变体方面灵活性有限。在本节中,我们介绍了一种由 FlexAttention 和 torch.compile 启用的 PT2 原生 PagedAttention 实现。

PagedAttention 将 KV 缓存分散存储以减少内存碎片并支持更大的 batch size。没有 PagedAttention 时,来自同一请求的 KV 缓存存储在连续内存中,需要形状为 B x H x KV LEN x D 的 2 个张量。我们称之为逻辑 KV 缓存。在这里,KV_LEN 是 batch 中所有请求的最大序列长度。参照图 1(a),KV_LEN 为 9,因此所有请求必须填充到 9 个 token,导致大量的内存浪费。通过 PagedAttention,我们可以将每个请求分块为多个相同大小(page_size)的页面,并将这些页面分散到形状为 1 x H x max seq len x D 的物理 KV 缓存中(其中 max_seq_len = n_pages x page_size)。这避免了将请求填充到相同长度,从而节省了内存。具体而言,我们提供了一个 assign API,通过索引计算来更新 KV 缓存。

def assign(
    batch_idx: torch.Tensor,
    input_pos: torch.Tensor,
    k_val: torch.Tensor,
    v_val: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
) -> None

在此 assign API 背后的机制是一个页表(page table),它是将逻辑 KV 缓存映射到物理 KV 缓存的张量:

[batch_idx, logical_page_idx] -> physical_page_idx

assign 接收 k_valv_val,并根据页表的映射关系分散存储到物理 KV 缓存中。

带有页表的 Paged Attention

一个自然的问题是:如何将 PagedAttention 与 FlexAttention 集成以支持多样化的注意力变体?一个朴素的想法是在使用 FlexAttention 计算之前物化(materialize)逻辑 KV 缓存。但这会导致冗余的内存拷贝和糟糕的性能。另一个想法是构建一个专用的 CUDA 或 Triton 内核用于分页注意力,类似于现有的 PagedAttention 实现。然而,这增加了大量的人力投入和代码复杂性。

相反,我们设计了一种融合的间接内存访问方式,通过根据页表转换逻辑块掩码(logical block mask)来实现。在 FlexAttention 中,我们利用 BlockMask 来识别逻辑块并跳过冗余计算。虽然 Paged Attention 增加了一层额外的间接内存访问,但我们可以进一步将逻辑块掩码转换为对应于页表的物理块掩码,如图 2 所示。我们的 PagedAttention 实现通过 torch.gather 调用提供了一个 convert_logical_block_mask 方法。

def convert_logical_block_mask(
    block_mask: BlockMask,
    batch_idx: Optional[torch.Tensor] = None,
) -> BlockMask

通过块掩码转换实现的 Paged Attention

剩下一个问题是如何重写用户指定的 mask_modscore_mod 以适配 PagedAttention。当用户指定这些修改时,他们使用的是逻辑索引,而不知道运行时维护的页表。以下代码展示了运行时自动转换的过程,这对于用物理 kv 索引重写用户指定的修改是必要的。new_mask_mod 将采用 physical_kv_idx,将其转换回 logical_kv_idx,并在 logical_kv_idx 上应用用户指定的 mask_mod 以获取正确的掩码。为了提高效率,我们维护了从 physical_kv_block 到 logical_kv_block 的映射(physical_to_logical)以促进转换。为了正确性,我们使用 torch.where 调用将边界外的块屏蔽为 False。在将来自多个请求的逻辑 KV 缓存批处理到同一个物理 KV 缓存后,物理块的数量远多于每个请求的逻辑块数量。因此,在块掩码转换期间,物理块可能没有对应于特定请求的逻辑块。通过使用 torch.where 将其屏蔽为 False,我们可以确保不同请求的数据不会相互干扰,从而保证正确性。同样,我们可以自动转换 score_mod

def get_mask_mod(mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
    if mask_mod is None:
        mask_mod = noop_mask

    def new_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ):
        physical_kv_block = physical_kv_idx // page_size
        physical_kv_offset = physical_kv_idx % page_size
        logical_block_idx = physical_to_logical[b, physical_kv_block]
        logical_kv_idx = logical_block_idx * page_size + physical_kv_offset
        return torch.where(
            logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
        )

    return new_mask_mod

图 3 展示了 Paged Attention 的延迟表现 (代码)。总体而言,与仅使用 Flex Attention 相比,引入 Paged Attention 后带来的开销不到 5%。我们还观察到其性能与 Flash Attention v2 持平。一个 最小服务示例 进一步表明,在评估包含 1M 条 GPT-4 完成任务和 3.2M 条 GPT-3.5 完成任务的 OpenOrca 数据集 时,PagedAttention 可以支持高出 76 倍的 batch size。

Paged Attention:不同序列长度下的延迟表现

可训练偏置

FlexAttention 现在支持 score_mod 函数中的可训练参数。此功能使用户能够在 score_mod 实现中引用需要梯度的张量,并在训练期间通过这些参数自动反向传播梯度。

内存高效的梯度累积

FlexAttention 不会物化完整的注意力分数矩阵,而是使用原子加法(tl.atomic_add)来累积梯度。这种方法显著减少了内存使用,代价是在梯度计算中引入了一些非确定性。

处理广播操作

前向传播中的广播操作(例如 score + bias[h])在反向传播中需要特别考虑。当在 head 或其他维度内跨多个注意力分数广播张量时,我们需要将这些梯度约简(reduce)回原始张量形状。我们不通过物化完整的注意力分数矩阵来进行此约简,而是使用原子操作。虽然这会带来一些运行时开销,但它避免了大型中间张量的物化,从而保持了内存效率。

当前限制

目前的实现仅允许在 score_mod 函数中对每个输入张量进行单次读取。例如,bias[q_idx] + bias[kv_idx] 是不支持的,因为它读取了同一个张量两次。我们希望在将来移除此限制。

简单示例

bias = torch.randn(num_heads, requires_grad=True)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]  

FlexAttention 的性能调优

要点速览

为了获得最佳性能,请使用 max-autotune 编译 FlexAttention,特别是在处理复杂的 score_modsmask_mods 时:

flex_attention = torch.compile(flex_attention, dynamic=True, mode='max-autotune')

什么是 max-autotune

max-autotune 是一种 torch.compile 模式,TorchInductor 会在该模式下扫描许多内核参数(如分块大小、num_stages)并选择性能最佳的配置。此过程允许内核测试成功和失败的配置而不产生问题,并找到最佳的可行配置。

虽然使用 max-autotune 编译需要更长的时间,但最优配置会被缓存以供后续内核执行使用。

以下是使用 max-autotune 编译 FlexAttention 的示例:

triton_flex_attention_backward_7 0.2528 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=7, HAS_FULL_BLOCKS=False, IS_DIVISIBLE=False, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=128, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.08838834764831843, SPARSE_KV_BLOCK_SIZE=1073741824, SPARSE_Q_BLOCK_SIZE=1073741824, V_HEAD_DIM=128, num_stages=4, num_warps=4

为什么 FlexAttention 要使用 max-autotune

FlexAttention 中利用的共享内存量取决于 score_modmask_mod 方法。这种变异性意味着预配置的默认内核参数可能会在某些硬件上导致性能下降,甚至对于某些掩码/修改操作引发共享内存不足(out of shared memory)错误。

例如,使用文档掩码(document masks)时,默认配置可能会使 GPU 占用率减半,在某些 GPU 上将性能降低到其潜力的 ~75%。为避免此类问题,我们强烈建议启用 max-autotune

更新与增强

  • 现已作为 PyTorch 2.5.0 中的原型功能可用。
  • 修复了关键的正确性问题,包括影响同一 torch.compile 调用内多次调用 FlexAttention 的 bug。

扩展的架构支持

  • 支持任意序列长度——不再要求必须是 128 的倍数。
  • 通过 is_gqa=True 增加了原生分组查询注意力(GQA)支持。
  • 增强了维度的灵活性
    • 支持不同的 QK 和 V 头维度。
    • 支持非 2 的幂次方的头维度。
  • 可训练的注意力偏置(原型)。

底层改进

  • 新的融合 CPU 后端。
  • 改进了对 float32 输入的 TF32 处理。
  • 解决了各种动态形状问题。
  • 输出布局与查询步长(strides)匹配。

这些更新使得 FlexAttention 更加稳健和灵活,同时保持了其结合 PyTorch 的易用性与 FlashAttention 的性能优势这一核心承诺。