作者:Jay Shah 和 Ganesh Bikshandi (Colfax Research),Ying Zhang (Meta),Vijay Thakkar 和 Pradeep Ramani (NVIDIA),Tri Dao (TogetherAI 和 Princeton University)

注意力机制作为无处不在的 Transformer 架构的核心层,是大型语言模型和长上下文应用中的瓶颈。FlashAttention(和 FlashAttention-2)率先采用了一种方法,通过最大限度地减少内存读/写来加速 GPU 上的注意力机制,现在已被大多数用于加速 Transformer 训练和推理。这使得过去两年中 LLM 上下文长度大幅增加,从 2-4K(GPT-3、OPT)增加到 128K(GPT-4),甚至 1M(Llama 3)。然而,尽管 FlashAttention 取得了成功,但它尚未利用现代硬件的新功能,FlashAttention-2 在 H100 GPU 上的理论最大 FLOPS 利用率仅达到 35%。在这篇博文中,我们介绍了三种主要的加速 Hopper GPU 上的注意力机制的技术:利用 Tensor Cores 和 TMA 的异步性,通过 Warp 特化 (warp-specialization) (1) 重叠整体计算和数据移动,并 (2) 交错块状矩阵乘法和 softmax 操作;以及 (3) 利用硬件对 FP8 低精度的支持进行不连贯处理 (incoherent processing)。

我们很高兴发布 FlashAttention-3,它融合了这些技术。使用 FP16 时,它比 FlashAttention-2 快 1.5-2.0 倍,高达 740 TFLOPS,即 H100 理论最大 FLOPS 的 75% 利用率。使用 FP8 时,FlashAttention-3 接近 1.2 PFLOPS,误差比基准 FP8 注意力机制小 2.6 倍。

FlashAttention-3 可通过以下链接获取:https://github.com/Dao-AILab/flash-attention
论文

FlashAttention 回顾

FlashAttention 是一种算法,它重新排列注意力计算并利用分块 (tiling) 和重计算 (recomputation) 来显著加速并减少内存使用,从序列长度的平方级降至线性级。我们使用分块将输入块从 HBM(GPU 内存)加载到 SRAM(快速缓存),对该块执行注意力计算,并更新 HBM 中的输出。通过不将大型中间注意力矩阵写入 HBM,我们减少了内存读/写次数,从而带来了 2-4 倍的实际运行时间加速。

此处显示了 FlashAttention 前向传播的示意图:通过分块和 softmax 重新缩放,我们按块进行操作并避免从 HBM 读/写,同时在没有任何近似的情况下获得正确输出。

math equations

Hopper GPU 上的新硬件特性 - WGMMA、TMA、FP8

虽然 FlashAttention-2 在 Ampere (A100) GPU 上可以达到理论最大 FLOPS 的 70%,但它尚未利用 Hopper GPU 上的新特性来最大化性能。我们在此介绍一些新的 Hopper 特有特性及其重要性。

1. WGMMA (Warpgroup Matrix Multiply-Accumulate)。这项新特性利用了 Hopper 上新的 Tensor Cores,其吞吐量1远高于 Ampere 中旧的 mma.sync 指令(图片来自H100 白皮书)

image from the H100 white paper

2. TMA (Tensor Memory Accelerator)。这是一个特殊的硬件单元,用于加速全局内存和共享内存之间的数据传输,负责处理所有索引计算和越界预测。这释放了寄存器,寄存器是增加瓦片大小和效率的宝贵资源。

block diagram

3. 使用 FP8 实现低精度。这将 Tensor Core 吞吐量加倍(例如,FP16 为 989 TFLOPS,FP8 为 1978 TFLOPS),但通过使用更少的位表示浮点数而牺牲精度。

6x throughput

FlashAttention-3 利用了 Hopper 的所有这些新特性,使用了来自 NVIDIA 的 CUTLASS 库的强大抽象。

通过重写 FlashAttention 以使用这些新特性,我们已经可以显著加速它(例如,FlashAttention-2 FP16 前向传播从 350 TFLOPS 提升到大约 540-570 TFLOPS)。然而,Hopper 上新指令(WGMMA 和 TMA)的异步特性为重叠操作并由此提取更大性能提供了额外的算法机会。在这篇博文中,我们将解释两种特定于注意力机制的技术。Warp 特化 (warp specialization) 的通用技术,即使用单独的生产者和消费者 warps 进行 TMA 和 WGMMA,在 GEMM 的上下文中已在其他地方得到很好的介绍,并且在这里也适用。

异步性:重叠 GEMM 和 Softmax

为何重叠?

注意力机制将 GEMM(Q 和 K 之间以及注意力概率 P 和 V 之间的矩阵乘法)和 softmax 作为其两个主要操作。我们为什么需要重叠它们?GEMM 不是占了大部分 FLOPS 吗?只要 GEMM 够快(例如,使用 WGMMA 指令计算),GPU 不应该飞速运转吗?

问题在于,非矩阵乘法操作在现代加速器上比矩阵乘法操作慢得多。指数函数(用于 softmax)等特殊函数的吞吐量甚至低于浮点乘加;它们由多功能单元计算,该单元与浮点乘加或矩阵乘加单元是分开的。例如,H100 GPU SXM5 的 FP16 矩阵乘法吞吐量为 989 TFLOPS,但特殊函数仅为 3.9 TFLOPS(吞吐量慢 256 倍)2!对于头维度为 128 的情况,矩阵乘法 FLOPS 是指数函数 FLOPS 的 512 倍,这意味着指数函数可能耗费与矩阵乘法相当的时间(50%)。FP8 的情况更糟,因为矩阵乘法 FLOPS 快了两倍,但指数函数 FLOPS 速度不变。理想情况下,我们希望矩阵乘法和 softmax 并行运行。当 Tensor Cores 忙于矩阵乘法时,多功能单元应该计算指数!

跨 warpgroup 重叠和乒乓调度

重叠 GEMM 和 softmax 的第一个也是最简单的方法是什么都不做!Warp 调度器已经尝试调度 warps,以便当某些 warps 被阻塞(例如,等待 GEMM 结果)时,其他 warps 可以运行。也就是说,warp 调度器免费为我们完成了一部分重叠。

然而,我们可以通过手动进行一些调度来改进。例如,如果我们有两个 warpgroups(标记为 1 和 2——每个 warpgroup 是一组 4 个 warps),我们可以使用同步 barrier (bar.sync) 来让 warpgroup 1 先完成其 GEMMs(例如,一个迭代的 GEMM1 和下一个迭代的 GEMM0),然后 warpgroup 2 在 warpgroup 1 执行其 softmax 时执行其 GEMMs,依此类推。下图说明了这种“乒乓”调度,其中相同的颜色表示相同的迭代。

block chart

这将使我们能够在另一个 warpgroup 的 GEMMs 阴影下执行 softmax。当然,此图只是一个示意图;实际上调度并非如此清晰。尽管如此,乒乓调度可以将 FP16 注意力前向传播从大约 570 TFLOPS 提高到 620 TFLOPS(头维度 128,序列长度 8K)。

Warpgroup 内的 GEMM 和 Softmax 重叠

即使在一个 warpgroup 内,我们也可以在该 warpgroup 的 GEMMs 运行时,让一部分 softmax 也运行。此图说明了这一点,其中相同的颜色表示相同的迭代。

block chart

这种流水线处理将 FP16 注意力前向传播的吞吐量从大约 620 TFLOPS 提高到大约 640-660 TFLOPS,代价是更高的寄存器压力。我们需要更多寄存器来容纳 GEMMs 的累加器和 softmax 的输入/输出。总的来说,我们发现这项技术提供了有利的权衡。

低精度:使用不连贯处理减少量化误差

LLM 激活可能存在离群值,其幅度远大于其余特征。这些离群值使得量化变得困难,产生更大的量化误差。我们利用不连贯处理 (incoherent processing),这是量化文献中使用的一种技术(例如来自QuIP),它通过将查询和键与随机正交矩阵相乘来“分散”离群值并减少量化误差。特别是,我们使用 Hadamard 变换(带随机符号),这可以在 O(d log d) 而非 O(d^2) 时间内按注意力头执行,其中 d 是头维度。由于 Hadamard 变换是内存带宽受限的,因此它可以与旋转嵌入 (rotary embedding) 等先前的操作(也是内存带宽受限的)“免费”融合。

在我们的实验中,Q、K、V 从标准正态分布生成,但 0.1% 的条目具有较大幅度(模拟离群值),我们发现不连贯处理可以将量化误差减少 2.6 倍。下表显示了数值误差比较。详细信息请参见论文。

text diagram

注意力机制基准测试

我们展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 以及 Triton 和 cuDNN 中的实现进行比较(这二者都已经使用了 Hopper GPU 的新硬件特性)。

对于 FP16,我们看到比 FlashAttention-2 快约 1.6 倍至 1.8 倍。

speed charts

speed charts

对于 FP8,我们可以达到接近 1.2 PFLOPS!

speed charts

讨论

这篇博文强调了 FlashAttention 在 Hopper GPU 上的一些优化。其他优化(例如可变长度序列、持久化内核以及 FP8 的内核内转置)在论文中进行了介绍。

我们已经看到,设计利用其运行硬件的算法可以带来显著的效率提升,并解锁新的模型能力,例如长上下文。我们期待未来在 LLM 推理优化方面的工作,以及将我们的技术推广到其他硬件架构。

我们也期待 FlashAttention-3 在未来版本的 PyTorch 中集成。

注释

  1. 没有 wgmma 指令,旧的 mma.sync 指令只能达到 Hopper Tensor Cores 峰值吞吐量的约 ⅔:https://arxiv.org/abs/2402.13499v1 

  2. CUDA 编程指南指出,特殊函数的吞吐量为每个流式多处理器 (SM) 每个时钟周期 16 次操作。我们将 16 乘以 132 个 SM 和 1830 Mhz(用于计算 FP16 矩阵乘法的 989 TFLOPS 的时钟速度)得到 3.9 TFLOPS