在本篇博文中,我们介绍了广义点积注意力(GDPA)的内核设计。GDPA 是标准点积注意力(SDPA)的一种变体,它用不同的激活函数取代了 Softmax 操作,从而支持各种交互用例。该设计已被用于 InterFormer [2] 和 Kunlun [1] 的注意力模块中,这两者均部署在 Meta 的生成式广告模型 (GEM) [3] 中,这是 Meta 规模最大的推荐系统训练基础模型。基于这些模型中观察到的真实注意力负载,我们在 Tri Dao 的 Flash Attention 4 内核 (FA4) 基础上,引入了一系列针对大批次训练、可变序列长度和非 Softmax 激活优化的工作负载驱动型改进方案。
在 Meta 集群中部署的 750W 功耗限制的 NVIDIA B200 GPU 上进行评估,我们优化的 GDPA 内核在正向传递中实现了高达 2 倍的加速,达到 1,145 BF16 Tensor Core TFLOPs(约 97% 的张量核心利用率);在反向传递中实现了高达 1.6 倍的加速,达到 702 BF16 TFLOPs,相比原始的基于 Triton 的实现。除了注意力机制,相同的设计原则也可以推广到其他处理现实中不规则形状数据的内核上。当应用于整个模型时,这些定制内核使训练吞吐量提高了 30% 以上,证明了生产驱动型内核设计的有效性。总体而言,在某些实际生产流量设置下,相比 FA4(一种 SOTA 注意力内核),我们的方法在正向传递中实现了高达 3.5 倍的加速,在反向传递中实现了 1.6 倍的加速。
代码仓库:https://github.com/facebookresearch/ads_model_kernel_library/blob/main/gdpa/README.md
† 工作完成时在 Meta 任职
‡ 工作完成时在普林斯顿大学任职
1. 推荐系统训练负载中的 GDPA
广义点积注意力(GDPA)是一种广泛使用的交互模式,特别是在推荐系统(RecSys)模型中。它将标准点积注意力推广到了 Softmax 公式之外。GDPA 不受限于基于概率的归一化,而是允许通过自定义的逐元素激活函数来转换注意力分数。这种设计已被多个生产级推荐系统架构采用。例如,Kunlun [1] 在其 PFFN 模块中应用了 GELU 激活,而 Meta 的另一个大型推荐模型 HSTU [4] 则利用 SiLU 激活,以在序列建模中更好地保持分数幅值。

图 1. (a) GEM 中的自注意力。(b) GEM 中的 PMA。(c) InterFormer 中的 PFFN。通过 GDPA 内核,我们将所有注意力内核统一为一种相似的实现。
GDPA 涵盖了生产推荐系统中广泛使用的各类注意力模块(如自注意力、PMA 和 PFFN),它们共享如图 1 所示的两次矩阵乘法及中间可选激活的共同模式。通过在 GDPA 公式下统一这些模块,我们可以设计一个单一的高性能内核,以在各种实际训练负载中实现高效优化。在本文中,我们以 GELU 为例来阐述优化过程。
2. 实际训练负载中的挑战
起初,我们观察到改编自最新 Triton 模板 的原始 GDPA 训练内核在实际生产负载下的表现不佳。如图 2 所示,我们将实际生产内核性能与 CUTLASS FMHA 基准进行了对比,后者是我们针对生产形状测试出的最快 FlashAttention 内核。实际生产内核使用真实数据进行评估,而基准测试则使用具有相同最大序列长度的合成数据。关键差异在于数据分布:真实数据由用户行为驱动,不遵循固定分布,而合成数据则使用正态分布生成。在 NVIDIA B200 上,真实运行环境下正向传递存在 2.6 倍的性能差距,反向传递存在 1.6 倍的差距,最坏情况下的差距甚至高达 4 倍。

图 2. 实际生产环境与基准测试的内核性能对比。正向(左)比基准测试低 2.6 倍,反向(右)低 1.6 倍,最坏情况下差距达 4 倍。
我们的分析表明,这种差距源于面向 LLM(大语言模型)的内核设计与生产级推荐系统工作负载之间的根本不匹配。实际流量主要由短序列和非对称序列、大批次、锯齿状(jagged)输入组成,这显著降低了流水线占用率并限制了计算与内存的重叠。这些观察结果促使我们重新设计内核,明确考虑高度动态的短序列输入,具体讨论见下文。
3. GDPA 训练内核的设计与优化
我们的目标是针对实际生产流量优化内核,并将性能推向硬件极限。以 FA4 内核(在我们评估中,这是 LLM 风格形状下 GPU 上的最优选择)为起点,我们通过重新思考流水线、调度和核心计算,为 GDPA 训练负载重新设计了内核。
3.1 为 GDPA 训练重构流水线
在 FlashAttention 内核 [5] 中,线程束专用化(warp specialization)主要由 Softmax 计算驱动,由单独的线程束组负责 Softmax 评估、校正和尾声(epilogue)阶段。在 GDPA 中,Softmax 计算被逐元素激活取代,且消除了 Softmax 校正阶段。如图 3 所示,我们简化了线程束专用流水线,完全删除了校正阶段,并将 TMEM 到 SMEM 的尾声加载任务合并到激活阶段,而不是分配给专门的线程束组。这种设计减少了 4 个总线程束,并为剩余线程束释放了寄存器资源;特别是激活线程束每个线程束额外获得了 16 个寄存器,如图 3 所示。

图 3. 为 GDPA 正向内核重构 FlashAttention 流水线。校正阶段被删除,激活线程束在尾声中接管了其角色,包括将结果从 TMEM 写入 SMEM。通过消除 4 个线程束,每个激活线程束多出了 16 个寄存器。
另一个主要瓶颈来自生产负载中常见的极短 K/V 序列。在持久化内核中,注意力内核通常被表示为双重嵌套循环,其中内循环遍历 K/V 维度。原始的内循环设计假设有足够的迭代次数来分摊流水线设置成本(如下面的 Alg.1 所示),但当内循环仅运行几次迭代时,该假设不再成立(例如,当 kv 序列长度为 128 或 256,块大小为 128 时,内循环仅运行一次或两次)。在这种情况下,内循环上的软件流水线效率会大幅下降。

Alg. 1-2. 算法 1 在传统的 MMA 线程束流水线中使用内循环软件流水线 (SWP),而算法 2 应用外循环 SWP 以更好地处理短 K/V 序列。
为了解决这个问题,我们将内循环展平到外循环中,并在外循环级别应用软件流水线 (SWP),如 Alg. 2 所示。思路很简单:在序言(prologue)中,我们在执行当前迭代的第一阶段 p·v 的同时计算前一次迭代的 qk。在尾声中,我们完成剩余的 p·v 阶段,将结果从张量内存写入共享内存,并尽快释放不再需要的中间存储。这实现了跨迭代的重叠,如图 4 所示有效地对多个迭代进行了流水线处理。实际上,当 kv 长度较短时,这种优化带来了约 10% 的提升。

图 4. 循环展平后,MMA 和激活线程束的重叠更好,相比图 3,减少了内循环中存在的流水线气泡。
3.2 针对锯齿状张量的新型软件负载均衡算法
锯齿状序列在生产负载中很常见,严重限制了 GPU 利用率,这是现有 Flash Attention 内核在生产模型中表现不佳的关键原因。大多数 FlashAttention 风格的持久化调度程序是为密集输入构建的:它们枚举跨批次、头和 M 维度的块,并隐式假设每个块承载相似的工作量。
在实践中,序列长度是动态的,且在内核启动时通常未知,因此锯齿状负载通过假设最大长度并在运行时检查块有效性来调度。这导致调度了许多空块,而有效块由于 K/V 长度不同,仍可能具有高度可变的工作负载,导致持久化 SM 不平衡。简单的启发式算法(如按长度对序列进行重新排序)不足以解决问题,因为调度程序仍然无法感知块的有效性和工作负载异构性。
为了解决这些问题,我们通过利用现有的序列长度元数据在 CPU 上预先计算有效块,将负载均衡移至软件层面。在 GPU 内核受限的训练负载中,这种 CPU 预处理带来的开销可以忽略不计,因为它在每次迭代中只执行一次,并分摊到各层中。持久化内核随后仅在真实块上调度工作,消除了空块执行并减少了不平衡。
为了更好地理解我们的块调度器设计,我们将该过程分为两步说明。首先,我们专注于沿 Q 维度平衡工作负载。具体来说,我们消除了由锯齿状序列引起的所有空 Q 块,并以循环方式将剩余的有效 Q 块分配给 SM,而不是遵循固定的批次-头-M 排序。如图 5 所示,这消除了空操作执行,并在第一步中部分平衡了跨 SM 的工作负载。
然而,对于交叉注意力负载,仅靠 Q 长度平衡是不够的,因为每个 Q 块的工作量仍随动态 K/V 长度变化,如图 5 步骤 2 所示。因此,我们进一步按 K/V 块计数对块进行排序,并在 SM 之间应用轻量级的锯齿形分配模式,即以从长到短和从短到长的交替波次调度块,从而平滑残留的不平衡。这种组合策略将工作负载偏差从很大的范围(例如,最大与最小 SM 之间 12 vs 2 个块)降低到了更紧密的分布(例如 5 vs 4 个块)。

图 5. 由于长度可变的输入,初始 SM 工作负载不平衡。我们首先通过移除空块并以循环方式分配有效块来平衡 Q 工作负载,然后使用锯齿形调度策略进一步平衡 K/V 工作负载。
总之,这两个步骤构成了我们的锯齿形块调度算法,总结于 Alg. 3 中。由于两个步骤都是可向量化的,CPU 上的调度开销极小;否则,顺序调度策略会更合适。该算法为 SM 利用率带来了稳定的提升。

Alg. 3. 使用预计算的有效块和锯齿形分配来平衡持久化内核的锯齿状张量的软件级块调度
我们还将类似的软件级块调度应用于反向传递,关键区别在于 Q 和 K/V 的角色互换,以匹配反向循环结构,因为 Q 序列长度是在内循环中迭代的。
3.3 数学优化
现代注意力机制和 FFN 内核计算密度越来越高,但 GPU 计算资源并不均衡。特别是执行超越函数(如 exp, tanh)的特殊功能单元(SFU)比 CUDA 核心要稀缺得多。严重依赖 SFU 的内核很容易受到 SFU 限制,即使在张量核心未充分利用的情况下也是如此。在实践中,这限制了进一步的扩展。例如,我们的基准 GELU 实现依赖于 tanh.approx.ftz,尽管进行了进一步优化,但仍受限于 SFU。
受 FA4 内核 的启发,该内核通过使用 exp 的软件近似在 SFU 和 CUDA 核心之间重新分配超越计算来缓解 SFU 瓶颈,我们也对 GELU 应用了类似的想法。我们不是仅近似 tanh 分量,而是使用仅包含 ALU 的泰勒展开来近似整个 GELU 函数,如公式 1 所示。我们发现这种方法比模拟单个 SFU 操作表现更好,因为它能够吸收应用于 tanh 输入和输出的多项式变换。

公式 1. GELU 近似对比。上式显示了标准的基于 tanh 的 GELU 近似,使用 1 个 SFU 指令和 8 个 ALU 指令。下式显示了仅包含 ALU 的 GELU 泰勒展开,使用 9 个 ALU 指令。
如图 6 所示,这种近似确实在输入幅度变大时引入了精度限制。然而,在我们的生产模型中,每个注意力块前面都有 QK-norm。模型还采用额外的归一化层和全局裁剪来实现其他目的,共同限制了输入分布,确保泰勒近似保持准确。因此,该方法在实践中效果良好,并有效地缓解了我们生产负载中的 SFU 瓶颈。此外,这种方法可以应用于正向和反向内核以及其他使用 SFU 的激活函数。

图 6. 标准基于 tanh 的 GELU 近似与 6 阶泰勒近似的对比。泰勒展开仅在有限的输入范围内准确;在我们的生产模型中,RMSNorm 将激活限制在此范围内。
3.4 反向传递的持久化调度
持久化调度技术在隐藏延迟方面非常有效。在我们的案例中,它在反向传递中带来了 5% 的 FLOPS 提升。为了在反向传递中实现持久化调度,一个关键的变化是添加额外的同步来编排异步生产者和消费者线程束之间的内存队列。
内存队列设计为多次计算重用了相同的张量内存和共享内存(由于张量内存和共享内存空间有限)。在图 7 中,我们展示了内存重用的细节。与我们额外同步相关的是,在第 4 阶段,sQ 的共享内存空间被 sdK 重用,生产者是 Q_reduce 线程束,消费者是 TMA,sdO 的内存也是如此。因此,sQ 和 sdO 的共享内存空间有两个生产者。正因如此,需要为加载线程束(即将 Q, K, V, dO 从全局内存加载到共享内存的线程束)增加一个同步,以等待线程束完成存储 dK 和 dV,因为 dK, dV 使用与 dQ, dO 相同的共享内存空间,如果等待不当,可能会被加载线程束覆盖。

图 7. 反向传递中跨全局内存、共享内存和张量内存的数据流。
4. 基准测试
4.1 基准测试设置
本博文中的所有基准测试均在 Meta 内部集群的 NVIDIA B200 GPU(约 180 GB HBM, CUDA 13.0)上运行,每个 GPU 的功耗限制为 750 W,使用默认 GPU 时钟频率。
我们使用张量核心吞吐量 (TFLOPs) 和相对于基准实现的相对加速来报告内核级性能。我们评估了多种 GDPA 内核变体,包括:
- Triton GDPA (基准)。源自 Triton 模板 [6] 的基于 Triton 的实现,作为我们的主要基准。
- CUTLASS FMHA。目前在 Blackwell GPU 上运行的 CUTLASS 的 FMHA 内核。
- FA4 内核。FlashAttention-4 内核,我们在本博文中基于其设计并针对 GDPA 进行了调整。
- CuteDSL GDPA (我们的)。本工作中开发的优化 GDPA 内核,基于 CuteDSL 注意力内核,并针对实际 GDPA 训练负载进行了定制。
4.2 关键结果
我们评估了 GDPA 内核在自注意力和交叉注意力负载下的性能。在自注意力中,查询 (Query)、键 (Key) 和值 (Value) 张量由相同的嵌入生成,而在交叉注意力中,查询关注来自不同来源的键和值。由于自注意力是我们目标训练负载中的主导模式,我们首先详细展示自注意力的结果,随后是交叉注意力的结果。
为了表征输入的非规则性,我们改变了输入稀疏度,定义为批次内平均序列长度与最大序列长度之比。稀疏度 1.0 对应于完全密集的输入(所有序列长度相同),而较低的稀疏度值表示具有动态序列长度的锯齿状输入。例如,稀疏度 = 0.5 表示平均序列长度是最大长度的一半,反映了具有显著长度变化的真实训练流量。
如图 8 所示,在密集输入下,Blackwell FMHA 的表现明显优于 Triton GDPA 基准,吞吐量高出约 30%。然而,这种优势在锯齿状输入下会崩溃。当稀疏度降至 0.5 时,FMHA 和基准测试都会出现严重的性能衰退,且 FMHA 在所有序列长度上都退化到与基准测试几乎相同的水平。
相比之下,我们优化的 GDPA 内核在锯齿状输入下保持了鲁棒性。虽然所有内核在从密集形状过渡到锯齿状形状时都会变慢,但我们的内核始终保持了其密集输入性能的更大比例,即使在稀疏度 = 0.5 时,相对于基准仍保持约 1.7 倍的平均加速。这证明了面向密集的注意力内核无法推广到动态训练形状,而针对实际流量优化的 GDPA 设计在不规则工作负载下仍能保持高利用率。

图 8. 密集(左,稀疏度 = 1.0)和锯齿状(右,稀疏度 = 0.5)输入下的自注意力正向性能。我们优化的 GDPA 内核在两种情况下均持续保持更高的吞吐量。
我们还评估了自注意力反向性能。由于内存流量大和梯度累积,反向传递比正向传递更具挑战性。在密集输入下,优化后的 GDPA 内核实现了高达 1.6 倍的加速,达到约 702 BF16 TFLOPs,随着反向计算变得受带宽和同步限制,与 Blackwell FMHA 的差距缩小,如图 9 所示。在锯齿状输入下,尽管相对差异较小,但我们优化的内核相对于基准仍保持约 1.3 倍的平均加速。我们将此归因于反向传递的更高复杂度,这需要更精细的流水线调整。此外,由于反向内核通常受共享内存带宽限制,从 SDPA 到 GDPA 的转换提供的性能改进空间不如正向传递大,这为未来的进一步优化留下了潜力。

图 9. 密集(稀疏度 = 1.0)和锯齿状(稀疏度 = 0.5)输入下的自注意力反向性能。优化后的 GDPA 内核始终优于 Triton 基准,但由于反向计算受带宽和同步限制的特性,提升幅度比正向传递略为温和。
我们进一步评估了生产设置下具有短 K/V 序列 (K = V = 256) 的交叉注意力负载。在广泛的 Q 长度范围内,优化后的 GDPA 内核相比 Triton 基准实现了近 2 倍的正向和 1.6 倍的反向加速,而 Blackwell FMHA 的性能很快达到瓶颈,如图 10 所示。增加 Q 长度带来的增益有限,因为内核主要受限于内循环,且具有足够的外循环并行性。值得注意的是,在短 K/V 设置下,我们优化后的 GDPA 内核显著优于基准注意力内核,与 FA4 相比在正向传递中实现了高达 3.5 倍的加速,反向传递中实现了 1.6 倍的加速。这有效解决了 FA4 在具有短 K/V 序列的现实流量模式下出现的性能下降问题。

图 10. 短 K/V (256) 下的交叉注意力性能。我们优化后的 GDPA 在正向和反向中均随 Q 长度扩展,而 FA4 在短 K/V 机制下达到瓶颈。在实际生产流量设置下,相比 FA4,我们分别实现了 3.5 倍的正向和 1.6 倍的反向平均加速。
最后,我们在实际生产流量下评估了优化后的内核。如图 11 所示,与基准测试相比,正向和反向内核均实现了高达 2 倍的加速,性能明显接近基准测试设置中观察到的理论峰值。虽然仍存在差距,但我们将其主要归因于实际工作负载中极高的数据随机性,某些极端情况阻碍了完美均衡的执行。作为未来工作,我们计划进一步针对实际流量模式调整内核,以更好地缩小这一差距。

图 11. 实际生产与基准测试的内核性能对比。我们优化后的 GDPA 内核在正向和反向传递中均始终优于基准测试,实现了高达 2 倍的加速。
5. 结论
我们提出了一种针对推荐系统训练负载优化的生产驱动型 GDPA 内核设计。这项工作展示了如何在学术界开发的内核级增益转化为实际、可投入生产的改进。针对以锯齿状输入、短且不对称的 K/V 维度以及非 Softmax 激活为特征的工作负载,我们重构了内核流水线,引入了软件级块调度,并将计算从 SFU 中重新平衡。结果表明,我们优化的 GDPA 内核在 NVIDIA B200 GPU 上相对于 Triton 基准实现了高达 2 倍的正向和 1.6 倍的反向加速,并且在不规则的现实输入形状下显著优于 SOTA 注意力内核。当应用于整个模型中的多个模块时,这些定制内核提供了超过 30% 的端到端吞吐量提升,凸显了生产驱动型内核设计对现代 GPU 的影响。
致谢
我们要感谢 Tri Dao、Markus Hoehnerbach、Jay Shah、Ted Zadouri 和 Vijay。感谢他们的开源工作,这为我们的 GDPA 内核设计和本文中探索的许多想法奠定了基础。
参考文献
[1] Kunlun: 通过统一架构设计确立大规模推荐系统的扩展定律:https://arxiv.org/abs/2602.10016
[2] InterFormer: 用于点击率预测的高效异构交互学习:https://arxiv.org/abs/2411.09852
[3] Meta 的生成式广告模型 (GEM):加速广告推荐 AI 创新的中央大脑:https://engineering.fb.com/2025/11/10/ml-applications/metas-generative-ads-model-gem-the-central-brain-accelerating-ads-recommendation-ai-innovation/
[4] 行动胜于雄辩:用于生成式推荐的万亿参数序列转导器:https://arxiv.org/pdf/2402.17152
[5] FlashAttention-3: 具有异步和低精度的快速且精确的注意力机制:https://arxiv.org/pdf/2407.08608
[6] Triton 教程:06-fused-attention:https://triton-lang.cn/main/getting-started/tutorials/06-fused-attention.html