博客

FlashAttention-3:通过异步和低精度实现快速准确的注意力机制

Attention(注意力机制)作为无处不在的 Transformer 架构的核心层,是大型语言模型和长上下文应用中的瓶颈。FlashAttention(及 FlashAttention-2)开创了一种通过最小化内存读写来加速 GPU 上注意力计算的方法,目前已被大多数所采用,用于加速 Transformer 的训练和推理。这促成了过去两年中 LLM 上下文长度的巨大增长,从 2-4K(GPT-3、OPT)提升到了 128K(GPT-4),甚至 1M(Llama 3)。然而,尽管取得了成功,FlashAttention 尚未充分利用现代硬件的新特性,FlashAttention-2 在 H100 GPU 上仅达到了理论最大浮点运算能力(FLOPs)的 35% 利用率。在本篇博文中,我们介绍了在 Hopper GPU 上加速注意力计算的三种主要技术:利用张量核心(Tensor Cores)和 TMA(张量内存加速器)的异步特性,通过 (1) 线程束专门化(warp-specialization)重叠计算与数据移动;(2) 交错块级矩阵乘法与 Softmax 操作;以及 (3) 利用硬件支持 FP8 低精度进行不相干处理(incoherent processing)。

我们很高兴发布 FlashAttention-3,它整合了这些技术。在 FP16 下,它的速度比 FlashAttention-2 快 1.5-2.0 倍,最高可达 740 TFLOPS,即 H100 理论最大 FLOPS 的 75% 利用率。在 FP8 下,FlashAttention-3 达到了近 1.2 PFLOPS,且误差比基准 FP8 注意力低 2.6 倍。

FlashAttention-3 获取地址: https://github.com/Dao-AILab/flash-attention
论文

FlashAttention 回顾

FlashAttention 是一种重排注意力计算顺序的算法,它利用分块(tiling)和重计算(recomputation)来显著加速计算,并将内存使用量从序列长度的二次方降低到线性。我们使用分块技术将输入块从 HBM(GPU 内存)加载到 SRAM(快速缓存),对该块执行注意力计算,然后将输出更新回 HBM。通过不将大型中间注意力矩阵写入 HBM,我们减少了内存读写量,从而带来了 2-4 倍的运行时间加速。

此处展示了 FlashAttention 前向传播的示意图:通过分块和 Softmax 重缩放,我们按块进行操作,避免了从 HBM 频繁读写,同时获得了正确的输出,且没有任何近似。

math equations

Hopper GPU 的新硬件特性 – WGMMA, TMA, FP8

虽然 FlashAttention-2 在 Ampere (A100) GPU 上可以达到理论最大 FLOPS 的 70%,但它尚未利用 Hopper GPU 上的新特性来最大化性能。我们在此描述一些 Hopper 特有的新特性,以及它们为何重要。

1. WGMMA (Warpgroup Matrix Multiply-Accumulate):这一新特性利用了 Hopper 上的新张量核心,其吞吐量远高于1 Ampere 架构中的旧版 mma.sync 指令(图片来自 H100 白皮书)。

image from the H100 white paper

2. TMA (Tensor Memory Accelerator):这是一个专用的硬件单元,用于加速全局内存与共享内存之间的数据传输,并处理所有的索引计算和边界谓词判断。这释放了寄存器,对于增加分块大小和提高效率是非常宝贵的资源。

block diagram

3. FP8 低精度:这使张量核心的吞吐量翻倍(例如 FP16 为 989 TFLOPS,FP8 为 1978 TFLOPS),但通过使用更少的位数表示浮点数,在精度上做出了权衡。

6x throughput

FlashAttention-3 利用了 NVIDIA CUTLASS 库的强大抽象能力,充分使用了 Hopper 的所有新特性。

通过重写 FlashAttention 以利用这些新特性,我们已经能够显著提升速度(例如,从 FlashAttention-2 FP16 前向传播的 350 TFLOPS 提升到约 540-570 TFLOPS)。然而,Hopper 上新指令(WGMMA 和 TMA)的异步性质开启了更多的算法优化空间,以重叠操作并获得更高的性能。对于本篇博文,我们将解释两种针对注意力机制的特定技术。通用的线程束专门化(Warp Specialization)技术(即独立的生产者和消费者线程束分别执行 TMA 和 WGMMA)已在 GEMM 的上下文中被广泛讨论,在此同样适用。

异步性:重叠 GEMM 和 Softmax

为什么要重叠?

注意力机制主要由 GEMM(Q 与 K 之间的矩阵乘法,以及注意力概率 P 与 V 之间的矩阵乘法)和 Softmax 这两个核心操作组成。为什么需要重叠它们?难道 GEMM 不占用了大部分 FLOPS 吗?只要 GEMM 足够快(例如使用 WGMMA 指令计算),GPU 不就应该运转飞快了吗

问题在于,在现代加速器上,非矩阵乘法操作比矩阵乘法操作要慢得多。像指数函数(用于 Softmax)这样的特殊函数,其吞吐量远低于浮点乘加运算;它们由多功能单元(multi-function unit)评估,这是一个独立于浮点乘加或矩阵乘加的单元。以 H100 GPU SXM5 为例,它拥有 989 TFLOPS 的 FP16 矩阵乘法能力,但特殊函数的吞吐量仅为 3.9 TFLOPS(慢了 256 倍)2!对于 128 的头维度,矩阵乘法的 FLOPS 是指数函数的 512 倍,这意味着指数计算可能占用相比矩阵乘法 50% 的时间。在 FP8 的情况下情况更糟,矩阵乘法速度提升了两倍,但指数计算速度保持不变。理想情况下,我们希望矩阵乘法和 Softmax 并行运行。当张量核心忙于矩阵乘法时,多功能单元应该在计算指数!

使用乒乓调度(pingpong scheduling)进行线程束组间重叠

重叠 GEMM 和 Softmax 的第一种也是最简单的方法是“顺其自然”!线程束调度器(warp schedulers)本身就会尝试调度线程束,以便当某些线程束被阻塞(例如等待 GEMM 结果)时,其他线程束可以运行。也就是说,调度器已经为我们免费完成了一部分重叠工作。

然而,我们可以通过手动进行一些调度来改进这一点。例如,如果我们有两个线程束组(标记为 1 和 2,每个线程束组由 4 个线程束组成),我们可以使用同步屏障(bar.sync),让线程束组 1 首先执行其 GEMM(例如,当前迭代的 GEMM1 和下一迭代的 GEMM0),然后当线程束组 1 执行 Softmax 时,线程束组 2 执行其 GEMM,依此类推。这种“乒乓”调度如下图所示,其中相同的颜色表示相同的迭代。

block chart

这将允许我们在另一个线程束组执行 GEMM 的同时进行 Softmax。当然,该图只是一个简化的示意;在实践中,调度并不完全如此规整。尽管如此,乒乓调度可以将 FP16 注意力前向传播的速度从约 570 TFLOPS 提升到 620 TFLOPS(头维度 128,序列长度 8K)。

线程束组内 GEMM 和 Softmax 的重叠

即使在同一个线程束组内,我们也可以在 GEMM 运行的同时运行部分 Softmax。这如下图所示,相同的颜色表示相同的迭代。

block chart

这种流水线技术将 FP16 注意力前向传播的吞吐量从约 620 TFLOPS 提高到约 640-660 TFLOPS,代价是更高的寄存器压力。我们需要更多的寄存器来同时保存 GEMM 的累加器以及 Softmax 的输入/输出。总体而言,我们认为这项技术提供了一个有利的折中方案。

低精度:通过不相干处理减少量化误差

LLM 的激活值可能包含远大于其他特征的异常值。这些异常值使得量化变得困难,产生更大的量化误差。我们利用“不相干处理”(incoherent processing),这是一种量化文献中(例如 QuIP)使用的技术,通过将 Query 和 Key 与随机正交矩阵相乘来“分散”异常值,从而减少量化误差。具体而言,我们使用哈达玛变换(Hadamard transform,带有随机符号),它可以在每个注意力头中以 O(d log d) 的时间复杂度完成,而不是 O(d^2),其中 d 是头维度。由于哈达玛变换受内存带宽限制,它可以与旋转嵌入(rotary embedding,同样受内存带宽限制)等先前操作融合在一起,从而实现“零成本”计算。

在实验中,我们从标准正态分布生成 Q、K、V,并让 0.1% 的条目具有较大数量级(以模拟异常值),结果发现不相干处理可以将量化误差降低 2.6 倍。我们在下表中展示了数值误差的比较。详情请参阅论文。

text diagram

注意力性能测试

我们展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2,以及 Triton 和 cuDNN 中的实现(两者都已经使用了 Hopper GPU 的新硬件特性)进行了比较。

对于 FP16,我们看到比 FlashAttention-2 提升了约 1.6 倍-1.8 倍的速度。

speed charts
speed charts

对于 FP8,我们可以达到近 1.2 PFLOPS!

speed charts

讨论

本篇博文强调了 Hopper GPU 上针对 FlashAttention 的部分优化。其他优化(如变长序列、持久化内核以及针对 FP8 的内核内转置)在论文中都有介绍。

我们已经看到,设计能够利用所运行硬件特性的算法,可以带来显著的效率提升,并解锁如长上下文等新的模型能力。我们期待未来在 LLM 推理优化方面的工作,以及将这些技术推广到其他硬件架构中。

我们也期待 FlashAttention-3 能被集成到 PyTorch 的未来版本中。

注释

  1. 如果不使用 wgmma 指令,旧的 mma.sync 指令只能达到 Hopper 张量核心峰值吞吐量的约 2/3:https://arxiv.org/abs/2402.13499v1 
  2. CUDA 编程指南指出,特殊函数的吞吐量为每个流多处理器(SM)每个时钟周期 16 次操作。我们将 16 乘以 132 个 SM 和 1830 MHz(计算 989 TFLOPS 的 FP16 矩阵乘法所使用的时钟速度)得到 3.9 TFLOPS