跳转到主要内容
博客

FlashAttention-3:通过异步和低精度实现快速准确的注意力机制

作为无处不在的 Transformer 架构的核心层,注意力(Attention)机制是大型语言模型和长上下文应用的瓶颈。FlashAttention(以及 FlashAttention-2)开创了一种通过最小化内存读写来加速 GPU 上注意力计算的方法,现在大多数都使用它来加速 Transformer 的训练和推理。这使得过去两年中 LLM 的上下文长度大幅增加,从 2-4K(GPT-3, OPT)增加到 128K(GPT-4),甚至 1M(Llama 3)。然而,尽管取得了成功,FlashAttention 尚未利用现代硬件的新功能,FlashAttention-2 在 H100 GPU 上仅达到理论最高 FLOPs 的 35%。在这篇博文中,我们描述了在 Hopper GPU 上加速注意力计算的三种主要技术:利用 Tensor Core 和 TMA 的异步性来 (1) 通过 warp 专用化(warp-specialization)重叠整体计算和数据移动,(2) 交错执行块级(block-wise)矩阵乘法和 softmax 操作,以及 (3) 利用硬件对 FP8 低精度的支持进行非相干处理(incoherent processing)。

我们很高兴发布集成了这些技术的 FlashAttention-3。在 FP16 精度下,它比 FlashAttention-2 快 1.5-2.0 倍,性能高达 740 TFLOPS,即 H100 理论最高 FLOPS 的 75%。在 FP8 精度下,FlashAttention-3 的性能接近 1.2 PFLOPS,同时误差比基准 FP8 注意力小 2.6 倍。

FlashAttention-3 可在此处获取:https://github.com/Dao-AILab/flash-attention
论文

FlashAttention 回顾

FlashAttention 是一种通过重排注意力计算顺序、并利用分块(tiling)和重计算来显著加速计算并减少内存使用的算法,将内存使用量从与序列长度的二次方关系降低到线性关系。我们使用分块技术将输入块从 HBM(GPU 内存)加载到 SRAM(高速缓存)中,对该块执行注意力计算,并更新 HBM 中的输出。通过不将庞大的中间注意力矩阵写入 HBM,我们减少了内存读写的次数,从而将实际运行时间缩短了 2-4 倍。

这里我们展示了 FlashAttention 前向传播的示意图:通过分块和 softmax 重缩放,我们按块进行操作,避免了从 HBM 读写,同时在没有近似的情况下获得了正确的输出。

math equations

Hopper GPU 的新硬件特性 – WGMMA、TMA、FP8

虽然 FlashAttention-2 在 Ampere (A100) GPU 上可以达到理论最高 FLOPS 的 70%,但它尚未利用 Hopper GPU 的新特性来最大化性能。我们在此描述一些 Hopper 特有的新功能及其重要性。

1. WGMMA (Warpgroup Matrix Multiply-Accumulate,Warp 组矩阵乘法累加)。这一新特性利用了 Hopper 上新的 Tensor Core,其吞吐量1 远高于 Ampere 中的旧 `mma.sync` 指令(图片来自H100 白皮书)。

image from the H100 white paper

2. TMA (Tensor Memory Accelerator,张量内存加速器)。这是一个特殊的硬件单元,用于加速全局内存和共享内存之间的数据传输,并处理所有索引计算和越界预测。这释放了寄存器,而寄存器是增加分块大小和效率的宝贵资源。

block diagram

3. FP8 低精度。这使得 Tensor Core 的吞吐量翻倍(例如,FP16 为 989 TFLOPS,FP8 为 1978 TFLOPS),但通过使用更少的比特来表示浮点数,牺牲了精度。

6x throughput

FlashAttention-3 利用了 Hopper 的所有这些新特性,使用了来自 NVIDIA CUTLASS 库的强大抽象。

通过重写 FlashAttention 以使用这些新特性,我们已经可以显著加速它(例如,FP16 前向传播从 FlashAttention-2 的 350 TFLOPS 提升到约 540-570 TFLOPS)。然而,Hopper 上新指令(WGMMA 和 TMA)的异步特性为重叠操作带来了额外的算法机会,从而可以获得更高的性能。在这篇博文中,我们将解释两种针对注意力机制的此类技术。通用的 warp 专用化技术,即使用独立的生产者和消费者 warp 分别执行 TMA 和 WGMMA,已在 GEMM 的背景下被其他地方充分介绍,在这里的原理相同。

异步性:重叠 GEMM 和 Softmax

为什么要重叠?

注意力机制的两个主要操作是 GEMM(Q 和 K 之间以及注意力概率 P 和 V 之间的矩阵乘法)和 softmax。我们为什么需要重叠它们呢?难道大部分的 FLOPs 不都在 GEMM 中吗?只要 GEMM 速度够快(例如,使用 WGMMA 指令计算),GPU 不就应该全速运转了吗

问题在于,在现代加速器上,非矩阵乘法操作远比矩阵乘法操作慢。诸如指数函数(用于 softmax)之类的特殊函数的吞吐量甚至低于浮点乘加运算;它们由多功能单元(multi-function unit)评估,这是一个独立于浮点乘加或矩阵乘加的单元。例如,H100 GPU SXM5 具有 989 TFLOPS 的 FP16 矩阵乘法性能,但特殊函数的性能仅为 3.9 TFLOPS(吞吐量低 256 倍)2!对于头维度为 128 的情况,矩阵乘法的 FLOPs 是指数函数 FLOPs 的 512 倍,这意味着指数函数所花费的时间可能占到矩阵乘法时间的 50%。对于 FP8 来说,情况更糟,矩阵乘法 FLOPs 速度翻倍,而指数函数 FLOPs 的速度保持不变。理想情况下,我们希望矩阵乘法和 softmax 能够并行操作。当 Tensor Core 忙于矩阵乘法时,多功能单元应该在计算指数函数!

使用乒乓调度实现 Warp 组间重叠

重叠 GEMM 和 softmax 的第一种也是最简单的方法是完全不采取任何措施!Warp 调度器已经会尝试调度 warp,以便在某些 warp 被阻塞时(例如,等待 GEMM 结果),其他 warp 可以运行。也就是说,warp 调度器免费为我们完成了部分重叠工作。

然而,我们可以通过手动进行一些调度来改进这一点。例如,如果我们有两个 warp 组(标记为 1 和 2——每个 warp 组是 4 个 warp 的集合),我们可以使用同步屏障(bar.sync),让 warp 组 1 先执行其 GEMM(例如,一次迭代的 GEMM1 和下一次迭代的 GEMM0),然后当 warp 组 1 执行其 softmax 时,warp 组 2 执行其 GEMM,依此类推。这种“乒乓”调度如下图所示,其中相同颜色表示同一次迭代。

block chart

这将使我们能够在另一个 warp 组执行 GEMM 的“阴影”下执行 softmax。当然,这张图只是一个示意图;在实践中,调度并非如此清晰。尽管如此,乒乓调度可以将 FP16 注意力前向传播的性能从约 570 TFLOPS 提升到 620 TFLOPS(头维度 128,序列长度 8K)。

Warp 组内 GEMM 和 Softmax 的重叠

即使在一个 warp 组内部,我们也可以在该 warp 组的 GEMM 运行时运行部分 softmax。这在这张图中得到了说明,其中相同颜色表示同一次迭代。

block chart

这种流水线操作将 FP16 注意力前向传播的吞吐量从约 620 TFLOPS 提升到约 640-660 TFLOPS,但代价是更高的寄存器压力。我们需要更多的寄存器来同时保存 GEMM 的累加器和 softmax 的输入/输出。总的来说,我们发现这种技术提供了一个有利的权衡。

低精度:通过非相干处理减少量化误差

LLM 的激活值中可能存在异常值,其量级远大于其余特征。这些异常值使得量化变得困难,会产生大得多的量化误差。我们利用了非相干处理(incoherent processing)技术,这是一种在量化文献中使用的技术(例如来自 QuIP),它将查询(query)和键(key)与一个随机正交矩阵相乘,以“分散”异常值并减少量化误差。特别地,我们使用哈达玛变换(Hadamard transform)(带有随机符号),这可以对每个注意力头在 O(d log d) 时间内完成,而不是 O(d^2) 时间,其中 d 是头维度。由于哈达玛变换是内存带宽受限的,它可以与之前的操作(如同样是内存带宽受限的旋转位置编码)“免费”地融合在一起。

在我们的实验中,Q、K、V 来自标准正态分布,但其中 0.1% 的条目具有很大的量级(以模拟异常值),我们发现非相干处理可以将量化误差减少 2.6 倍。我们在下表中显示了数值误差的比较。详情请参阅论文。

text diagram

注意力基准测试

我们展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 以及 Triton 和 cuDNN 中的实现进行了比较(后两者都已使用 Hopper GPU 的新硬件特性)。

对于 FP16,我们看到比 FlashAttention-2 快约 1.6-1.8 倍

speed charts
speed charts

对于 FP8,我们可以达到接近 1.2 PFLOPS!

speed charts

讨论

这篇博文重点介绍了一些在 Hopper GPU 上可用的 FlashAttention 优化。其他优化(例如,可变长度序列、持久化内核以及用于 FP8 的核内转置)在论文中有详细介绍。

我们已经看到,设计能够利用其运行硬件的算法可以带来显著的效率提升,并解锁新的模型能力,如长上下文。我们期待未来在 LLM 推理优化方面的工作,以及将我们的技术推广到其他硬件架构。

我们也期待 FlashAttention-3 在未来的 PyTorch 版本中被集成。

注释

  1. 如果没有 wgmma 指令,旧的 mma.sync 指令只能达到 Hopper Tensor Core 峰值吞吐量的大约 ⅔:https://arxiv.org/abs/2402.13499v1
  2. CUDA 编程指南指出,特殊函数的吞吐量为每个流式多处理器(SM)每个时钟周期 16 次操作。我们将 16 乘以 132 个 SM 和 1830 Mhz(用于计算 989 TFLOPS FP16 矩阵乘法性能的时钟速度),得到 3.9 TFLOPS