博客

FlexAttention + FlashAttention-4:既快又灵活

要点速览

在 Hopper 和 Blackwell GPU 上,FlexAttention 现在拥有了 FlashAttention-4 后端。

我们在 PyTorch 中增加了对自动生成 CuTeDSL 分数/掩码修改函数的支持,并实现了针对自定义注意力变体的 FlashAttention-4 即时(JIT)实例化。

在计算密集型工作负载上,这使得性能相比现有的 Triton 实现提升了 1.2 倍至 3.2 倍。

FlexAttention 回顾

FlexAttention 是一个 PyTorch API,让你只需几行 Python 代码即可实现自定义注意力变体,无需编写 CUDA。你只需编写一个 score_modmask_mod 函数来修改注意力分数,编译器会处理剩下的工作:ALiBi、滑动窗口、文档掩码、软截断(soft-capping)及其组合均可通过同一接口实现。

从底层来看,它在原生 FlashAttention 之上进行了两个扩展:

  1. 对 Softmax 前分数的逐点修改,支持从全局内存进行任意加载。
  2. 前向和反向传播的块稀疏迭代,并使用简单的数据结构在运行时编码与数据相关的稀疏性。

就是这样。当然,魔鬼藏在细节中,但正如我们在原始 FlexAttention 博客推理版 FlexAttention 博客中所展示的,这两个扩展涵盖了广泛的主流注意力变体。

此次发布后,FlexAttention 现在拥有了 FlashAttention-4 (FA4) 后端。使用方法如下:

import torch
from functools import partial

from torch.nn.attention.flex_attention import flex_attention

flex_flash = torch.compile(
    partial(flex_attention, kernel_options={"BACKEND": "FLASH"}), dynamic=False
)

def local_boost(score, b_idx, h_idx, q_idx, kv_idx):
    return torch.where(torch.abs(q_idx - kv_idx) <= 8, score * 2, score)

B, H, S, D = 2, 8, 2048, 128
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
out = flex_flash(q, k, v, score_mod=local_boost)

BACKEND 设置为 "FLASH" 以使用 FA4 后端。你需要安装较新的 PyTorch nightly 版本和较新的 FlashAttention 代码库;请查看安装文档以了解版本兼容性。这是处于快速开发中的代码,在趋于稳定前可能会有一些破坏性变更。

用 FlexAttention 普及注意力研究

FlexAttention 最初设计(并命名)的初衷是为 AI 研究人员在原型设计和尝试新的注意力变体时提供灵活性。实践证明确实如此:已有数十篇论文引用了 FlexAttention,且超过一千个代码库采用了它:

尽管 Flex 在赋能研究人员方面取得了成功,但用户不断反馈的一个问题是,他们最终会遇到难以突破的性能上限。在最初发布博客文章时,我们将它与 Hopper GPU 上的 FlashAttention-3 (FA3) 进行了比较,性能大约是后者的 80%。

如果今天再进行测量,尽管两种实现都有改进,但 FlexAttention 的吞吐量仅约为 FlashAttention-3 的 60%!

一个常见的模式出现了:研究人员用 Flex 进行实验,找到可行方案后,一旦性能成为关键,就会撞上“墙”。此时,专家必须将其移植到底层实现中。FlashAttention-3 不断增加新参数以扩展功能,但每一个新参数/模式都需要底层重写。我们将负担从研究人员转移到了 ML 工程师身上!

在 Hopper 上,完全优化版本之间的性能差异或许值得用灵活性来交换,但在更新的 Blackwell GPU 上情况则完全不同。

让我们看看在 Blackwell GB200 GPU(1000W 功耗)上,使用现有 Triton 实现(开启所有自动调优开关)的 FlexAttention 与通过 PyTorch SDPA 可用的高度优化实现(如 cuDNN attention)之间的对比:

曾经微小的差距如今已变成一道鸿沟!

Blackwell:更大的张量核心,更大的挑战

在 Blackwell 上,高性能注意力需要深度流水线化、针对 Warp 优化的内核。这些技术在基于 Triton 的实现中无法表达。我们推荐阅读 逆向工程 FlashAttention-4 中的出色解释,它详细介绍了该实现如何利用 Blackwell 的新硬件功能,以及 softmax 计算方式的更新。游戏的规则一如既往:保持张量核心的繁忙;由于它们变得更快了,这需要大量使用深度异步流水线。

Blackwell 引入了张量内存(TMEM),这是一种程序员可管理的、靠近张量核心的暂存器,用于存储中间结果。更重要的是,数据移动和矩阵乘法现在都是完全异步的。Warp 可以启动矩阵乘法或加载操作并立即继续执行后续任务。

Warp 特化将工作拆分为多个阶段:一些 Warp 处理同步工作(如需要寄存器的 softmax),而另一些 Warp 通过发出加载和矩阵乘法请求来编排异步流水线,并协调同步。由于编排 Warp 的寄存器压力较低,更多的操作可以同时保持在运行状态,从而隐藏延迟。

张量核心变大变快了,但负责指数运算等特殊函数单元(SFU)的进步速度却没跟上。对于前向注意力,这改变了瓶颈:softmax 的 exp() 运算现在和矩阵乘法一样昂贵。为了让 GPU 保持完全饱和,你需要在这两个 tile 之间进行乒乓切换,将一个 tile 的矩阵乘法与另一个 tile 的指数运算重叠。下面的时间轴显示了这些阶段如何交替以隐藏延迟。

反向传播更加棘手。TMEM 中没有足够的空间同时容纳所有的累加器,因此内核需要精细的流水线化,以便在共享内存、寄存器和张量内存处于沉重压力下时,将计算与数据移动重叠。

这正是通用编译器难以轻易发现的底层编排。正如 Gluon 简介 所言:“虽然 Triton 编译器在为各种内核生成高效代码方面做得很好,但它仍会被手动调优的底层代码击败。在这种情况下,用户几乎无法做任何事来显著提高性能,因为所有的细节都被隐藏了。”对于 FlexAttention 来说这更难,因为它是一种元注意力实现,当模式由用户定义时,硬编码针对特定模式的编译器优化非常困难。正因如此,我们开始研究底层实现,以寻找提高 Blackwell 性能的最佳方式。

以 FlashAttention-4 为基础

在针对 Blackwell 新硬件的注意力实现方面,存在着很大的变数。cuDNN 较早地添加了高性能注意力支持,但 FA3(Hopper 上的现有 SOTA 实现)无法在 Blackwell 上工作。WGMMA 在 SM100 上已不再存在:它已被 TCGEN05 张量核心指令取代,且张量核心操作需要不同的内存空间(张量内存)。

Tri Dao 等人开始研究 FlashAttention-4,这是一个能够充分利用硬件的更新版本实现。

从 FA3 到 FA4 的一个重大变化是 CuTeDSL,这是 NVIDIA CUTLASS 团队最近发布的一个 Python DSL,用于使用 CUTLASS 抽象编写高性能 CUDA 内核。SOTA 注意力实现大量使用了 CUTLASS 抽象(如 cute.Layouts),但任何尝试过安装 FlashAttention 的人都知道这有多痛苦,且编译时间很长。因此,虽然之前有人提出过用 CUTLASS C++ 重写 Flex 的想法,但 FlexAttention 的动态特性(以及编译开销)使得这一前提不再具有吸引力。CuTeDSL 允许用 Python 编写过去需要 CUTLASS C++ 才能实现的功能,这使得 JIT 风格的工作流对 FlexAttention 来说更具实用性。

在与 Tri Dao 对此路径进行初步讨论后,我们决定联手为 FlexAttention 和 FlashAttention-4 共同推进这一实现。

我们没有构建一个单独的实现,而是通过协作直接扩展了 FA4,共享相同的异步流水线基础设施,并在 FlexAttention 需要注入分数修改和稀疏性的地方增加了扩展点。

这意味着在前向和反向传播中都增加了分数修改支持(通过将分数修改内联到 FlashAttention 实现中),并增加了对 FlexAttention 中使用的块稀疏元数据的支持。

这项工作大致分为两部分:FA4 的变更以生成 FlexAttention 模板实例化,以及 Inductor 的更新以从 PyTorch 表示生成所需的 CuTeDSL 代码。

Inductor → CuTeDSL:胶水层

那么我们需要为 FlexAttention 生成什么呢?逐点修改和任意加载。幸运的是,Inductor 并不是第一次做这件事,这种扩展已有现成的机制。例如,让我们看看一个可用于实现 ALiBi 的分数修改(有趣的是:这正是 FlexAttention 项目的推动示例)。

粗略地说,torch.compile 获取用户代码并将其通过多个 IR(中间表示)进行转换。这些转换会产生越来越底层的表示。在 FX IR 中,你仍然可以看到熟悉的 PyTorch 算子,以及在使用后立即被置为 None 的变量。AOTAutograd 传递会自动生成反向传播:由于 (X + A) 的导数等于 1,链式法则将梯度直接传递。

值得注意的是,在到达 Inductor 并最终生成待执行的内核代码之前,该堆栈的任何部分都不需要“知道”什么是 CuTeDSL 代码。

点击下面的标签页,查看 ALiBi 如何从用户代码演变为最终的 CuTeDSL 内核。

原始 FX IR AOTAutograd CuTeDSL




def alibi_mod(score, b, h, q_idx, kv_idx):
    scale = torch.exp2(-((h + 1) * 8.0 / H))
    bias = (kv_idx - q_idx) * scale
    return score + bias
用户代码:一个实现 ALiBi 的分数修改,这是 flex-attention 项目的动机示例。

Inductor 将逐点 IR 降低为一种 define-by-run 函数,该函数调用 V.ops.<op>,然后交换一个处理程序,为目标后端重新解释这些调用。实际上,这表现为 ops_wrapper(...)OpsWrapper,它们允许你将一元和二元原语映射到一种新语言,而无需更改 IR 本身。对于 CuTeDSL,我们插入了一个 CuTeDSL 处理程序,将这些操作重写为 TensorSSA 表达式,因此算术运算是在寄存器(RMEM)支持的 cute 张量上执行的,且表达式可以进行 CSE(公共子表达式消除)。

我们还为“任意加载”添加了一个专门的加载路径。如果用户编写了一个依赖于某个全局张量的分数/掩码修改,我们会实例化一个 RMEM 片段,并在(可能是间接的)索引处发出加载指令。这使我们能够从 Inductor 的索引表达式桥接到 CuTeDSL 的 TensorSSA。

FlashAttention-4 的 Flex 化

我们对 FA4 进行了两次正交扩展,使其能够作为 FlexAttention 的后端:

  1. 前向和反向传播中的分数修改
  2. 前向和反向传播中的块稀疏迭代

这两个扩展都是用 CuTeDSL 实现的,因此它们可以内联到使 FA4 快速的同一个异步流水线中。

想象 FlashAttention 处理每个 SM 的 KV tile 队列。Flex 化增加了两个钩子:块稀疏性控制哪些 tile 进入队列(跳过空块,标记部分块),而分数/掩码修改则作为 softmax Warp 中的逐点操作应用。

考虑到这种拆分,以下是前向和反向钩子的适配方式。

分数修改

使该项目可行的 CuTeDSL 特性之一是不仅能够向实现传递变长数量的内核参数(然后降低为特定的实例化),还可以直接传递用户可调用对象。回到 FlexAttention 的核心,我们需要在 FlashAttention 算法的精确点注入用户修改的能力。我们基于现有的 FA4 实现构建,该实现本身就是为了支持分数修改而编写的。

在前向传播中,我们将 S tile 从 TMEM 拿回寄存器中,以便我们可以应用修改、计算行最大值/和,并为第二个矩阵乘法生成 P tile。我们定义了一个镜像 FlexAttention score_mod 签名的 CuTeDSL 接口,并且不通过内核传递 N 个可变捕获,而是传递一个 aux_tensors 列表,代表修改中使用的任何全局内存区域。在内核内部,我们将寄存器片段重新解释为 TensorSSA 视图(具有可选的矢量化),并在这些 tile 上内联用户可调用对象。

我们需要在寄存器中保留 S 来计算最大值/和并形成 P tile,因此我们在数据驻留在 RMEM 时应用分数/掩码修改,而不是添加一个单独的阶段。这保持了相同的流水线结构以及 TCGEN 和 SFU 工作之间的重叠。任何来自 aux_tensors 的额外读取都在需要时直接发出,并与消耗 S 的现有阶段一起调度。

反向传播遵循相同的接口形状,使用生成的 score_mod_bwd 可调用对象,但活跃度情况不同。在标准 FA4 中,SdS tile 永远不需要同时处于活跃状态,因此 TMEM 可以跨阶段共享。通过分数修改,反向路径取决于用户导数的需求。

如果梯度仅依赖于 P(或输入梯度),我们保持默认调度并避免在 TMEM 中重叠 S/dS。如果导数依赖于 softmax 前的分数,我们将所需的 S 片段与 PdS 一起保存在寄存器中,并在其贡献被消耗后立即丢弃。TMEM 保留给主累加器使用,代价是这些特定修改会带来更高的寄存器压力。

块稀疏迭代(前向+反向)

FlexAttention 对 FlashAttention 的第二个要求是块稀疏迭代。我们扩展了 FA4 的内核以接受块掩码元数据(要访问的行/列 tile),并从该数据驱动 tile 调度程序,因此内核只触及掩码中存在的 (m, n) tile。我们还使块稀疏路径支持 GQA 打包和广播头部维度。

之前提到的双 tile 乒乓切换带来的一个后果是:Blackwell 上的最小稀疏块大小为 256×128,而 Triton 路径上为 128×128。由于每个 CTA 处理两个 M-tile 以保持流水线满载(q_stage=2),调度程序可以跳过的最小工作单位是 256 行,因此块掩码的粒度必须匹配。

反向传播遍历相同的块掩码,仅为前向传播中存在的 tile 计算梯度。反向内核已经使用了沿行的子 tile 迭代,因此 256 行的约束自然适配。

我们的贡献

这些扩展需要整个 FA4 堆栈的上游更新:

  • 前向和反向传播中的分数修改钩子,包括 SM90/SM100 正确性修复和 GQA 边缘情况处理。
  • Blackwell 和 Hopper 的块稀疏前向和反向路径,以及针对广播掩码修改的 GQA 打包支持。
  • 针对分数/掩码修改路径中连续布局和扩展张量的接口清理。
  • CuTeDSL 升级和 TVM-FFI 启用,以减少 CPU 分发开销。

有了这些基础,让我们看看性能表现。

结果

SDPA 支持的模式

对于像密集(noop)和因果掩码这样的标准注意力模式,我们可以将 FlexAttention 新的 Flash 后端与现有的 Triton 实现和 cuDNN 进行比较。

在 GB200 上,Flash 后端在前向传播上实现了比 Triton 1.6–3.2 倍的加速,在反向传播上实现了 1.85–2.3 倍的加速。对于反向传播,Flash 在某些情况下匹配甚至超越了 cuDNN;前向传播与 cuDNN 仍有一定差距,特别是在因果注意力方面。

你会注意到在前向传播中,Noop 与 cuDNN 匹配得非常接近,而因果注意力则落后更多。这一差距凸显了与 FA4 内置的因果路径相比,块稀疏迭代带来了多少额外开销。

为什么因果注意力会落后(以及如何缩小差距)

经过调查,罪魁祸首之一是工作调度:如果你阅读 FA3 代码,你会看到使用了 最长处理时间优先(LPT)调度,FA4 为内置因果实现了这一点,但 FlexAttention 没有使用。如果我们手动指定 LPT 调度,性能如下:

手动指定 LPT 调度后,前向传播在较短序列上看到了高达 1.6 倍的加速,在较长序列上逐渐减小至约 1.1 倍。反向传播差异极小,因为调度开销被以不同方式摊销了。我们仍然没有完全匹配性能,但正在接近。

LPT 调度在这里有效,因为我们知道特定的稀疏模式是因果的,且该调度对这种情况是最优的。通常,我们无法预先知道模式:块稀疏性可能与数据相关,不同的行拥有不同数量的活动 KV 块。

我们可以依赖 CUDA 以负载均衡的方式启动单个输出 tile,但这样我们会错过通过将 MMA 和加载与尾声(epilogues)重叠以及不重复序言(prologues)而获得的持久调度增益。这正是 集群启动控制 (CLC) 解决的问题。CLC 是一项 Blackwell 功能,支持动态工作调度:不再在启动时静态地将 tile 分配给 SM,工作者可以在运行时即时查询新 tile。当一个 SM 提前完成工作(因为其行处理的块较少)时,它会立即获取下一个可用 tile,而不是闲置。CuTeDSL 4.4 增加了对基于 CLC 的持久调度的支持,这使得 FlexAttention 可以透明地从块稀疏模式的更好工作分配中受益,无需用户指定调度。

FlexAttention 支持的模式

FlexAttention 真正是为了 SDPA 不支持的模式而设计的:ALiBi、文档掩码、滑动窗口和任意用户定义的分数修改。

对于 B200 上的这些仅限 Flex 的模式:

  • ALiBi:前向加速 1.2–2.1 倍,反向加速 1.9–2.9 倍
  • 文档掩码:在较长序列下前向最高加速 2.7 倍,反向最高 3 倍
  • 滑动窗口:前向加速 1.4–2.1 倍,反向加速 1.8–2.2 倍

Hopper (H200) 结果

在 Hopper GPU 上,Flash 在所有序列长度上始终更快。

对于 H200 上的这些仅限 Flex 的模式:

  • ALiBi:前向加速 1.30–1.54 倍,反向加速 1.36–1.65 倍
  • 文档掩码:前向 1.41–1.89 倍,反向 1.48–2.01 倍
  • 滑动窗口:前向 1.45–1.65 倍,反向 1.35–1.52 倍

这些收益即使在较短序列 (2K) 下也存在,并随着序列长度增加而进一步提升。

正确性与基准测试方法

本文中的所有基准测试数据均通过 attention-gym/benchmarks/flex_perf.py 生成。

正确性

我们通过将输出与 FP32 参考进行比较(将 Q/K/V 转换为 FP32,运行注意力,再转回)来验证 Flash 后端。上游测试套件会持续执行这些检查:

  • PyTorch Inductor:在 test/inductor/test_flex_flash.py 中对 score_mod / mask_mod 模式(包括捕获的缓冲区和视图)进行了广泛的矩阵测试,并进行 Flash-vs-Triton 的比较。
  • FlashAttention (CuTe):在 tests/cute/test_mask_mod.py 中对 mask_mod + 块稀疏性进行了压力测试,涵盖了许多 (seqlen_q, seqlen_k) 对,并对比 flex_attention 参考实现了前向和反向正确性验证。

除了单元测试,我们还在真实的训练设置中验证了 Flash 后端:使用 torchtitan 在 64 个 H100 GPU 上训练 Llama 3 70B,序列长度 8192。两次运行在 1000 个训练步骤后都收敛到了约 3.7 的最终损失。

局限性

块大小约束:对于分页注意力(如 vLLM 集成),通常将内核块与页面大小对齐。目前,FA4 路径在 Hopper 上针对 128×128 块、在 Blackwell 上针对 256×128 块(由于 q_stage=2)进行了调优,改变块大小的灵活性有限。随着 FA4 开放更健壮的小块 tile_m/tile_n 选项,我们计划启用此功能。

动态标量:完全支持动态张量形状,并在运行时解析。但是,在 score_modmask_mod 中捕获的标量会被硬编码到编译后的内核中。如果你有一个在调用之间变化的 soft_cap 值,每个唯一值都会触发一次重新编译。

def tanh_softcap(score, b, h, q_idx, kv_idx):
    return soft_cap * tanh(score / soft_cap)

需要梯度的捕获缓冲区的反向传播:目前在 Flash 后端中不支持。例如,可学习的偏置张量。

bias = torch.randn(seq_q, seq_kv, device='cuda', requires_grad=True)
def bias_func(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx, kv_idx]

Triton 后端支持捕获缓冲区的梯度;对于这些情况,请使用 Triton 后端。

带块稀疏性的确定性反向传播:当启用块稀疏性时,Flash 后端的反向传播尚未实现确定性(仅分数修改的工作负载是确定的)。我们正在积极修复此问题。

性能局限

  • 前向传播中对 KV 维度的加载可能会使流水线停滞,特别是对于指针追踪模式(例如具有每 token 元数据的文档掩码),其中 aux-tensor 加载很难与计算重叠。
  • 需要 softmax 前分数的 score_mods 的反向传播在当前平铺下几乎总是会导致寄存器溢出。例如,score**2 的梯度是 2 * score * grad_score,这需要在反向传播期间保持 softmax 前的分数处于活跃状态。TMEM 被主要的注意力累加器完全占用,且当前的块大小很少能在 SMEM 中为 S tile 留出空间,因此它保留在寄存器中并大量溢出,导致明显的减速。

未来工作

我们对 CuTeDSL 和 FA4 集成在缩小研究与生产之间差距感到兴奋。

特别是在 Flash 后端,我们正在努力支持在分数修改中捕获的动态标量,而无需重新编译(例如,在调用之间更改 soft_cap 值)。捕获缓冲区的梯度在可预见的未来仍将依赖 Triton 后端。我们也在探索动态持久调度,以自动改善跨块稀疏模式的工作分配。

虽然本文是关于 FA4 实现的,但 Triton 实现仍支持更广泛的硬件,我们计划继续改进这两个后端。

致谢

这是一次跨仓库的协作。

FlashAttention-4 内核工作(CuTeDSL 实现、调度以及分数/掩码修改和块稀疏性所需的扩展点)位于上游 Dao-AILab/flash-attention,而编译器 + 集成工作(FlexAttention API 行为、Inductor 降低和 CuTeDSL 代码生成)位于上游 pytorch/pytorch

感谢两个仓库的维护者、审阅者和贡献者,感谢 NVIDIA CUTLASS/CuTeDSL 团队构建了使 JIT 风格工作流变得实用的抽象。

  • FlashAttention / FA4(内核 + 扩展点):Tri Dao, Ted Zadouri, Reuben Stern, Markus Hoehnerbach, Jay Shah
  • PyTorch / Inductor(降低 + 代码生成 + 集成):Markus Hoehnerbach
  • CuTeDSL / CUTLASS:Fung Xie

延伸阅读 / 链接