注意力机制作为无处不在的 Transformer 架构的核心层,是大语言模型和长上下文应用的瓶颈。FlashAttention(和 FlashAttention-2)率先提出了一种通过最小化内存读/写来加速 GPU 上注意力计算的方法,现在已被大多数库用于加速 Transformer 的训练和推理。这促成了过去两年 LLM 上下文长度的大幅增加,从 2-4K(GPT-3、OPT)到 128K(GPT-4),甚至 1M(Llama 3)。然而,尽管取得了成功,FlashAttention 尚未利用现代硬件中的新功能,FlashAttention-2 在 H100 GPU 上仅实现了理论最大 FLOPs 的 35% 的利用率。在这篇博文中,我们描述了在 Hopper GPU 上加速注意力计算的三种主要技术:利用 Tensor Cores 和 TMA 的异步性,以(1)通过 warp-specialization 重叠整体计算和数据移动,以及(2)交错块式矩阵乘法和 softmax 操作,以及(3)不连贯处理,利用硬件对 FP8 低精度的支持。
我们很高兴发布 FlashAttention-3,它融合了这些技术。FP16 精度下,它比 FlashAttention-2 快 1.5-2.0 倍,高达 740 TFLOPS,即 H100 理论最大 FLOPS 的 75% 的利用率。在 FP8 精度下,FlashAttention-3 接近 1.2 PFLOPS,误差比基线 FP8 注意力机制小 2.6 倍。
FlashAttention-3 可在此处获取:https://github.com/Dao-AILab/flash-attention
论文
FlashAttention 回顾
FlashAttention 是一种算法,它重新排序注意力计算,并利用平铺和重计算来显著加速计算,并将内存使用量从序列长度的二次方降低到线性。我们使用平铺将输入块从 HBM(GPU 内存)加载到 SRAM(快速缓存),执行关于该块的注意力计算,并在 HBM 中更新输出。通过不将大型中间注意力矩阵写入 HBM,我们减少了内存读/写的量,从而带来了 2-4 倍的实际运行时间加速。
这里我们展示了 FlashAttention 前向传播的图示:通过平铺和 softmax 重缩放,我们按块操作,避免了从 HBM 读取/写入,同时获得了正确的输出,没有近似。
Hopper GPU 上的新硬件特性 - WGMMA、TMA、FP8
虽然 FlashAttention-2 可以在 Ampere (A100) GPU 上实现高达 70% 的理论最大 FLOPS,但它尚未利用 Hopper GPU 上的新特性来最大化性能。我们在这里描述一些新的 Hopper 特有功能,以及它们为何重要。
1. WGMMA(Warpgroup 矩阵乘法-累加)。这项新功能利用了 Hopper 上的新 Tensor Cores,具有比 Ampere 中旧的 mma.sync 指令更高的吞吐量1(图片来自 H100 白皮书)。
2. TMA(Tensor 内存加速器)。这是一个特殊的硬件单元,可以加速全局内存和共享内存之间的数据传输,负责所有索引计算和越界预测。这释放了寄存器,寄存器是增加瓦片大小和效率的宝贵资源。
3. FP8 低精度。这使 Tensor Core 吞吐量翻倍(例如,FP16 为 989 TFLOPS,FP8 为 1978 TFLOPS),但通过使用更少的位来表示浮点数,从而牺牲了精度。
FlashAttention-3 利用了 Hopper 的所有这些新特性,使用了 NVIDIA 的 CUTLASS 库中的强大抽象。
通过重写 FlashAttention 以使用这些新特性,我们已经可以显著加速它(例如,从 FlashAttention-2 FP16 前向传播中的 350 TFLOPS 提高到约 540-570 TFLOPS)。然而,Hopper 上新指令(WGMMA 和 TMA)的异步性质为重叠操作从而提取更高性能提供了额外的算法机会。对于这篇博文,我们将解释两种特定于注意力机制的技术。warp specialization 的通用技术,其中单独的生产者和消费者 warp 执行 TMA 和 WGMMA,在 GEMM 的上下文中已在其他地方得到充分介绍,并且在这里的工作原理相同。
异步性:重叠 GEMM 和 Softmax
为什么要重叠?
注意力机制有 GEMM(Q 和 K 之间以及注意力概率 P 和 V 之间的矩阵乘法)和 softmax 这两个主要操作。我们为什么需要重叠它们?无论如何,大部分 FLOPS 不是都在 GEMM 中吗?只要 GEMM 速度快(例如,使用 WGMMA 指令计算),GPU 不应该 “brrrr” 地运行 吗?
问题在于,在现代加速器上,非矩阵乘法操作比矩阵乘法操作慢得多。诸如指数(对于 softmax)之类的特殊函数的吞吐量甚至比浮点乘加运算更低;它们由多功能单元(一个与浮点乘加或矩阵乘加单元分开的单元)评估。例如,H100 GPU SXM5 具有 989 TFLOPS 的 FP16 矩阵乘法,但特殊函数的吞吐量仅为 3.9 TFLOPS(吞吐量降低 256 倍)2!对于头维度 128,矩阵乘法 FLOPS 比指数运算多 512 倍,这意味着指数运算可能占用矩阵乘法 50% 的时间。对于 FP8 来说,情况更糟,其中矩阵乘法 FLOPS 快两倍,但指数运算 FLOPS 保持相同的速度。理想情况下,我们希望矩阵乘法和 softmax 并行运行。当 Tensor Cores 忙于矩阵乘法时,多功能单元应该在计算指数!
使用乒乓调度的 Warpgroup 间重叠
重叠 GEMM 和 softmax 的第一种也是最简单的方法是什么都不做!warp 调度器已经尝试调度 warp,以便如果某些 warp 被阻塞(例如,等待 GEMM 结果),其他 warp 可以运行。也就是说,warp 调度器为我们免费完成了一些重叠操作。
但是,我们可以通过手动进行一些调度来改进这一点。例如,如果我们有 2 个 warpgroup(标记为 1 和 2 – 每个 warpgroup 是一组 4 个 warp),我们可以使用同步屏障 (bar.sync),以便 warpgroup 1 首先执行其 GEMM(例如,一次迭代的 GEMM1 和下一次迭代的 GEMM0),然后 warpgroup 2 执行其 GEMM,而 warpgroup 1 执行其 softmax,依此类推。下图说明了这种“乒乓”调度,其中相同的颜色表示相同的迭代。
这将使我们能够在另一个 warpgroup 的 GEMM 的阴影下执行 softmax。当然,此图只是一个漫画;实际上,调度并非如此干净。尽管如此,乒乓调度可以将 FP16 注意力机制前向传播从约 570 TFLOPS 提高到 620 TFLOPS(头维度 128,序列长度 8K)。
Warpgroup 内 GEMM 和 Softmax 的重叠
即使在一个 warpgroup 内,我们也可以在运行该 warpgroup 的 GEMM 时运行 softmax 的一部分。下图说明了这一点,其中相同的颜色表示相同的迭代。
这种流水线处理将 FP16 注意力机制前向传播的吞吐量从约 620 TFLOPS 提高到约 640-660 TFLOPS,但代价是更高的寄存器压力。我们需要更多寄存器来保存 GEMM 的累加器以及 softmax 的输入/输出。总的来说,我们发现这项技术提供了有利的权衡。
低精度:通过不连贯处理减少量化误差
LLM 激活可能具有异常值,其幅度远大于其余特征。这些异常值使得量化变得困难,产生更大的量化误差。我们利用不连贯处理,这是量化文献(例如来自QuIP)中使用的一种技术,它将查询和键与随机正交矩阵相乘,以“分散”异常值并减少量化误差。特别是,我们使用 Hadamard 变换(带有随机符号),它可以按注意力头在 O(d log d) 而不是 O(d^2) 时间内完成,其中 d 是头维度。由于 Hadamard 变换受内存带宽限制,因此它可以与先前的操作(例如旋转嵌入,也受内存带宽限制)“免费”融合。
在我们的实验中,Q、K、V 是从标准正态分布生成的,但 0.1% 的条目具有很大的幅度(以模拟异常值),我们发现不连贯处理可以将量化误差减少 2.6 倍。我们在下表中显示了数值误差比较。请参阅论文了解详情。
注意力机制基准测试
我们展示了 FlashAttention-3 的一些结果,并将其与 FlashAttention-2 以及 Triton 和 cuDNN 中的实现(两者都已经使用了 Hopper GPU 的新硬件特性)进行了比较。
对于 FP16,我们看到比 FlashAttention-2 快约 1.6 倍-1.8 倍
对于 FP8,我们可以达到接近 1.2 PFLOPS!
讨论
这篇博文重点介绍了 Hopper GPU 上可用的 FlashAttention 的一些优化。论文中介绍了其他优化(例如,可变长度序列、持久内核和 FP8 的内核内转置)。
我们已经看到,设计利用其运行硬件的算法可以带来显著的效率提升,并解锁新的模型功能,例如长上下文。我们期待未来在 LLM 推理优化方面的工作,以及将我们的技术推广到其他硬件架构。
我们也期待 FlashAttention-3 被集成到 PyTorch 的未来版本中。