跳转到主要内容
博客

FlexAttention 第二部分:用于推理的 FlexAttention

概述

在 PyTorch 2.5.0 版本中,我们引入了 FlexAttention torch.nn.attention.flex_attention,专为希望在不编写内核代码的情况下自定义注意力内核的机器学习研究人员设计。本博客将介绍我们针对推理优化的解码后端,该后端支持 GQA 和 PagedAttention,以及性能调优指南和可训练偏置支持等功能更新。

如果您正在寻找一种在训练后/推理管道中轻松试用 FlexAttention 的方法,PyTorch 原生训练后库 torchtune 和推理代码库 gpt-fast 都已集成了 FlexAttention。快来试试吧!

我们很高兴地宣布,我们关于 FlexAttention 的论文已被 MLSys2025 会议接受,并将在 5 月 12 日至 15 日在加利福尼亚州圣克拉拉举行的会议上进行展示。

标题:FlexAttention:一种用于生成优化注意力内核的编程模型。 海报

用于推理的 FlexAttention

TL;DR:当在非常短的查询上运行时,torch.compile 会将 flex_attention 降低为融合的 FlashDecoding 内核。

一个融合的注意力内核无法满足所有需求——尤其是在长上下文 LLM 推理中。

LLM 推理的解码阶段是一个迭代过程:逐个生成 token,生成一个 N-token 句子需要 N 次前向传播。幸运的是,每次迭代都不需要重新计算整个句子的自注意力——之前计算的 token 会被缓存,因此我们只需要将新生成的 token 注意到缓存的上下文即可。

这导致了一种独特的注意力模式,其中一个短查询序列(1 个 token)注意到一个长键值缓存(上下文长度高达 128k)。传统的方块注意力内核优化(q_len ≈ kv_len)不直接适用于此处。这种模式对 GPU 内存利用率和占用率提出了新的挑战。我们构建了一个专用的 FlexDecoding 后端,针对长上下文 LLM 推理进行了优化,并结合了 FlashDecoding 中的解码特定技术。

FlexDecoding 作为 torch.nn.attention.flex_attention 运算符的替代后端实现。flex_attention 在给定短查询和长 KV 缓存时,会自动切换到 FlexDecoding 后端进行 JIT 编译。如果输入形状发生显著变化,例如从预填充阶段过渡到解码阶段,JIT 重编译会为每种情况生成一个单独的内核。

flex_attention = torch.compile(flex_attention)

k_cache = torch.random(B, H, 16384, D) 
v_cache = torch.random(B, H, 16384, D)

...

# Prefill Phase: query shape = [B, H, 8000, D]
flex_attention(q_prefill, k_cache, v_cache, ...) # Uses FlexAttention backend optimized for prefill & training

# Decoding Phase: q_last_token shape = [B, H, 1, D]
flex_attention(q_last_token  , k_cache, v_cache, ...) # Recompiles with the FlexDecoding backend 

# decode 2 tokens at the same time: q_last_2_tokens shape = [B, H, 2, D]
flex_attention(q_last_2_tokens, k_cache, v_cache, ...) # No recompilation needed! Runs the decoding kernel again.

使用 KV 缓存

提高推理效率的关键优化之一是维护一个预分配的 KV 缓存,该缓存会随着新 token 的生成而原地更新。FlexDecoding 不强制使用特定 KV 缓存策略和专用 API,而是允许用户自行定义和管理 KV 缓存。

与 FlexAttention 类似,FlexDecoding 接受用户定义的 mask_modscore_mod 函数。这些函数在 softmax 操作之前修改注意力分数。

score_mod(score, b, h, q_idx, kv_idx) -> tensor # return updated score

Score 是一个标量 pytorch 张量,表示查询 token 和键 token 的点积。其余参数指定正在计算哪个分数

  • b 批处理索引
  • h 注意力头索引
  • q_idx 查询张量中的 token 位置
  • kv_idx 键/值张量中的 token 位置

在解码阶段,先前计算的 token 会被缓存,并且只有最新生成的 token(第 i 个)用作查询。此单 token 查询上的简单因果掩码如下所示

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

这有问题:新 token “saw” 应该注意到所有先前生成的 token,即 “The cat sat on the mat and saw”,而不仅仅是 kv 缓存中的第一个条目。为了纠正这一点,score_mod 需要将 q_idx 偏移 i 以实现准确解码。

为每个 token 创建一个新的 score_mod 来适应偏移量会很慢,因为它意味着 FlexAttention 需要在每次迭代中针对不同的 score_mod 进行重新编译。相反,

我们将此 offset 定义为一个张量,并在每次迭代中增加其值

offset = torch.tensor(i, "cuda")
def causal_w_offset(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx + offset >= kv_idx, score, -float("inf"))

# Attend the i-th token
flex_attention(..., score_mod=causal_w_offset  ) # Compiles the kernel here 
...
# Attend the i+1-th token
offset = offset + 1 # Increment offset
flex_attention(..., score_mod=causal_w_offset ) # Doesn't need to recompile! 

值得注意的是,这里的 offset 成为一个捕获的张量,如果 offset 的值发生变化,它不需要重新编译。

无需手动重写 score_modmask_mod 以进行偏移处理。我们可以使用通用重写器自动执行此过程

offset = torch.tensor(i, "cuda")

def get_score_mod_w_offset(score_mod: _score_mod_signature, _offset: tensor):
    def _score_mod(score, b, h, q, kv):
        return score_mod(score, b, h, q + _offset, kv)
    return _score_mod

def get_mask_mod_w_offset(mask_mod: _mask_mod_signature, _offset: tensor):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + _offset, kv)
    return _mask_mod

causal_w_offset = get_score_mod_w_offset(causal, offset)

用于推理的 BlockMask

我们还可以使用 BlockMask 进行推理,以利用掩码稀疏性。其思想是在模型设置期间预先计算 BlockMask 一次,并在解码期间使用其切片

预计算 BlockMask

在设置期间,我们为 MAX_SEQ_LEN x MAX_SEQ_LEN 创建一个平方 BlockMask

from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=MAX_SEQ_LEN,KV_LEN=MAX_SEQ_LEN)

解码期间使用 BlockMask

对于第 i 个 token,我们使用掩码的一个切片

block_offset = i // block_mask.BLOCK_SIZE[0]
block_mask_slice = block_mask[:, :, block_offset]

# don't forget to use the mask_mod with offset! 
block_mask_slice.mask_mod = get_mask_mod_w_offset(causal_mask)

性能

FlexDecoding 内核的性能与 FlashDecoding (FAKV) 持平,并且显著优于 pytorch scaled_dot_product_attention (代码)。

与 gpt-fast 中的 SDPA 相比,FlexDecoding 将 LLaMa3.1-8B 的服务性能提升了 1.22x-2.04x,将 LLaMa3.1-70B 的性能提升了 0.99x – 1.66x。(代码

分页注意力

vLLM 是流行的 LLM 服务引擎之一,由 PagedAttention 的高效内存管理提供支持。现有的 PagedAttention 实现需要专用的 CUDA 内核,并且在支持新兴注意力变体方面显示出有限的灵活性。在本节中,我们将介绍一种 PT2 原生 PagedAttention 实现,该实现通过 flex attention 和 torch.compile 实现。

PagedAttention 散布 KV 缓存以减少内存碎片并支持更大的批处理大小。如果没有 PagedAttention,来自相同请求的 KV 缓存将存储在连续内存中,需要 2 个形状为 B x H x KV LEN x D 的张量。我们称之为逻辑 KV 缓存。这里,KV_LEN 是批处理中所有请求的最大序列长度。考虑到图 1(a),KV_LEN 为 9,因此所有请求都必须填充到 9 个 token,导致大量内存浪费。使用 PagedAttention,我们可以将每个请求分块为多个大小相同的页面 page_size,并将这些页面散布到形状为 1 x H x max seq len x D 的物理 KV 缓存中,其中 max_seq_len=n_pages x page_size。这避免了将请求填充到相同长度并节省了内存。具体来说,我们提供了一个 assign API,通过索引计算更新 KV 缓存

def assign(
    batch_idx: torch.Tensor,
    input_pos: torch.Tensor,
    k_val: torch.Tensor,
    v_val: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
) -> None

assign API 背后是一个页表,一个将逻辑 KV 缓存映射到物理 KV 缓存的张量

[batch_idx, logical_page_idx] -> physical_page_idx

assign 接受 k_valv_val,并根据页表中的映射将其分散到物理 KV 缓存中。

使用页表的分页注意力

一个自然的问题是,如何将 PagedAttention 与 flex attention 集成以支持不同的注意力变体?一个朴素的想法是在使用 flex attention 计算之前将逻辑 KV 缓存物化。但这会导致冗余内存复制和性能不佳。另一个想法是为分页注意力构建专用的 CUDA 或 Triton 内核,类似于 现有 PagedAttention 实现。然而,这增加了大量手动工作和代码复杂性。

相反,我们设计了一种融合的间接内存访问,通过根据页表转换逻辑块掩码。在 FlexAttention 中,我们利用 BlockMask 识别逻辑块并跳过冗余计算。虽然 Paged Attention 增加了额外的间接内存访问层,但我们可以进一步将逻辑块掩码转换为与页表对应的物理块掩码,如图 2 所示。我们的 PagedAttention 实现通过 torch.gather 调用提供了一个 convert_logical_block_mask

def convert_logical_block_mask(
    block_mask: BlockMask,
    batch_idx: Optional[torch.Tensor] = None,
) -> BlockMask

通过块掩码转换的分页注意力

剩下的一個問題是如何為分頁注意力重寫用戶指定的 mask_modscore_mod。當用戶指定這些修改時,他們使用邏輯索引編寫,而不知道運行時維護的頁表。以下代碼顯示了運行時的自動轉換,這是使用物理 kv 索引重寫用戶指定的修改所必需的。new_mask_mod 將接受 physical_kv_idx 並將其轉換回 logical_kv_idx,然後對 logical_kv_idx 應用用戶指定的 mask_mod 以獲得正確的掩碼。為了提高效率,我們維護 physical_to_logical 作為從 physical_kv_block 到 logical_kv_block 的映射,以方便轉換。為了正確性,我們使用 torch.where 調用將超出邊界的塊屏蔽為 False。將來自多個請求的邏輯 KV 緩存批處理到同一個物理 KV 緩存後,物理塊的數量遠多於每個請求的邏輯塊數量。因此,在塊掩碼轉換期間,物理塊可能沒有對應的邏輯塊。通過使用 torch.where 屏蔽為 False,我們可以確保來自不同請求的數據不會相互干擾的正確性。同樣,我們可以自動轉換 score_mod

def get_mask_mod(mask_mod: Optional[_mask_mod_signature]) -> _mask_mod_signature:
    if mask_mod is None:
        mask_mod = noop_mask

    def new_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ):
        physical_kv_block = physical_kv_idx // page_size
        physical_kv_offset = physical_kv_idx % page_size
        logical_block_idx = physical_to_logical[b, physical_kv_block]
        logical_kv_idx = logical_block_idx * page_size + physical_kv_offset
        return torch.where(
            logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
        )

    return new_mask_mod

图 3 展示了 Paged Attention 的延迟(代码)。总体而言,Flex Attention 与 Paged Attention 相比,与仅使用 Flex Attention 相比,开销不到 5%。我们还观察到与 Flash Attention v2 持平的性能。一个最小服务示例进一步表明,在评估 OpenOrca 数据集时,PagedAttention 可以支持 76 倍的批处理大小,该数据集包括 1M GPT-4 完成和 3.2M GPT-3.5 完成。

分页注意力:不同序列长度下的延迟

可训练偏置

FlexAttention 现在支持在 score_mod 函数中包含可训练参数。此功能允许用户在其 score_mod 实现中引用需要梯度的张量,并在训练期间自动通过这些参数反向传播梯度。

内存高效的梯度累积

FlexAttention 不会物化完整的注意力分数矩阵,而是使用原子加法 (tl.atomic_add) 来累积梯度。这种方法显著降低了内存使用量,但代价是引入了梯度计算中的一些非确定性。

处理广播操作

前向传播中的广播操作(例如,score + bias[h])在反向传播中需要特殊考虑。当在一个头或其他维度中将张量广播到多个注意力分数时,我们需要将这些梯度还原为原始张量形状。我们不物化完整的注意力分数矩阵来执行此归约,而是使用原子操作。虽然这会带来一些运行时开销,但它允许我们通过避免物化大型中间张量来保持内存效率。

当前限制

当前实现仅允许从 score_mod 函数中的每个输入张量进行单次读取。例如,不支持 bias[q_idx] + bias[kv_idx],因为它两次读取同一张量。我们希望将来能取消此限制。

简单示例

bias = torch.randn(num_heads, requires_grad=True)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h]  

FlexAttention 的性能调优

TL;DR

为了获得最佳性能,请使用 max-autotune 编译 FlexAttention,尤其是在处理复杂的 score_modsmask_mods

flex_attention = torch.compile(flex_attention, dynamic=True, mode=’max-autotune’)

什么是 max-autotune

max-autotunetorch.compile 的一种模式,在这种模式下,TorchInductor 会遍历许多内核参数(例如,tile size、num_stages)并选择性能最佳的配置。此过程允许内核测试成功和失败的配置而不会出现问题,并找到最佳可行配置。

虽然使用 max-autotune 编译需要更长时间,但最佳配置会缓存以供未来的内核执行使用。

以下是使用 max-autotune 编译的 FlexAttention 示例

triton_flex_attention_backward_7 0.2528 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=7, HAS_FULL_BLOCKS=False, IS_DIVISIBLE=False, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=128, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.08838834764831843, SPARSE_KV_BLOCK_SIZE=1073741824, SPARSE_Q_BLOCK_SIZE=1073741824, V_HEAD_DIM=128, num_stages=4, num_warps=4

为什么要为 FlexAttention 使用 max-autotune

FlexAttention 中共享内存的利用量取决于 score_modmask_mod 方法。这种可变性意味着预配置的默认内核参数可能导致性能急剧下降,甚至在某些硬件上对于某些掩码/模块出现共享内存不足错误。

例如,对于文档掩码,默认配置可以将 GPU 占用率减半,从而将某些 GPU 上的性能降低到其潜力的约 75%。为避免此类问题,我们强烈建议启用 max-autotune

更新与增强

  • 现已作为原型功能在 PyTorch 2.5.0 中提供
  • 修复了关键的正确性问题,包括影响在同一 torch.compile 调用中多次调用 FlexAttention 的错误

扩展了架构支持

  • 任意序列长度支持——不再需要 128 的倍数
  • 通过 is_gqa=True 添加了原生分组查询注意力 (GQA) 支持
  • 增强了维度灵活性
    • 不同的 QK 和 V 头维度
    • 非 2 的幂的头维度
  • 可训练注意力偏置(原型)

内部

  • 新的融合 CPU 后端
  • 改进了 float32 输入的 TF32 处理
  • 解决了各种动态形状问题
  • 输出布局与查询步幅匹配

这些更新使 FlexAttention 更加健壮和灵活,同时保持了其将 PyTorch 的易用性与 FlashAttention 的性能优势相结合的核心承诺。