Float8 (FP8) 的 2D 块量化有望在提高 Float8 量化精度的同时,为推理和训练任务加速通用矩阵乘法 (GEMM)。在本篇博客中,我们将展示如何利用 Triton 实现块量化 Float8 GEMM 的两个主要阶段。
针对 A 和 B 张量从高精度 (BFloat16) 到 Float8 的输入量化,我们推出了 GridQuant。它利用微网格步长循环 (mini-grid stride loop) 处理风格,相较于目前的 2D 块量化内核,速度提升了近 2 倍 (99.31%)。
对于 Float8 GEMM,我们展示了 Triton 的 3 项新进展——Warp 专用化 (Warp Specialization)、TMA 和持久化内核 (persistent kernel),从而有效地创建了一种协作式内核(这是 Ping-Pong 调度的一种替代方案)。最终,相较于我们去年性能最好的 SplitK 内核,实现了约 1.2 倍 的速度提升。

图 1:2D 量化在不同尺寸下相对于当前基准的速度提升对比。(数值越低越好)
为何 FP8 要采用 2D 分块量化?
总的来说,fp8 量化的精度会随着量化粒度的细化而提升,顺序为:张量级缩放 -> 行级缩放 -> 2D 块级缩放 -> 列级缩放。这是因为给定 token 的特征存储在每一列中,因此张量中的每一列具有更相似的缩放比例。
为了尽量减少给定数值集中的离群值,我们需要寻找共性,以便以相似的方式对数字进行缩放。对于 Transformer 模型而言,这意味着基于列的量化可能是最优的……然而,由于数据在内存中是以行连续的方式排列的,按列读取内存的效率极低。因此,列式加载需要在大步长内存访问中获取孤立值,这与高效内存访问的核心原则背道而驰。
然而,2D 是次优选方案,因为它既包含列式缩放的部分特性,又因为可以使用 2D 向量化加载而具有更高的内存访问效率。因此,我们希望寻找提高 2D 块量化速度的方法,这也是我们开发 GridQuant 内核的原因。
对于量化过程,我们需要对高精度的 BF16 输入张量(A = 输入激活值,B = 权重)进行 2D 块量化,然后使用量化后的张量及其 2D 块缩放值进行 Float8 矩阵乘法,最后返回 BF16 格式的输出张量 C。
GridQuant 如何提高 2D 块量化的效率?
与标准的基于瓦片 (tile) 的初始基准量化实现相比,GridQuant 内核有几项改进。GridQuant 内核对整个输入张量进行了两次完整遍历,工作流程如下:
阶段 1 – 从输入的高精度张量中确定每个 256×256 子块的最大绝对值。
1 – 我们将 BF16 张量划分为 256×256 的子块。该量化尺寸可配置,但 256×256 是默认设置,因为它在量化精度和处理效率之间取得了平衡。
2 – 每个 256×256 子块被细分为 64 个 8×8 排列的微型子块,每个微型子块处理一个 32×32 的元素块。一个线程束 (Warp, 32 个线程) 负责处理其分配的 32×32 块内所有元素的计算。
3 – 我们在共享内存中声明一个 32×32 的 max_vals 数组。当 2D 向量块在整个 256×256 子块中移动时,该数组将存储每个位置 i,j 的当前最大值。
这是一个重要的改进,因为它意味着我们可以对最大值评分系统进行向量化更新,而不是标量更新,从而实现更高效的更新。

图 2:输入张量的分块布局——在张量上创建一个 256×256 的网格,并在每个 256×256 块内进一步细分为 32×32 的子块。为每个 256×256 块创建一个 32×32 的 max_vals。
4 – 每个线程束处理一个 32×32 的块,由于我们使用 4 个线程束,我们确保 Triton 编译器可以将下一个 32×32 块的内存加载与当前块的 absmax 计算过程进行流水线处理。这确保了线程束调度器能够交替执行加载数据的线程束和处理计算的线程束,从而保持流处理器 (SM) 持续忙碌。
5 – 32×32 的 2D 向量块处理过程以网格步长循环的方式在整个 256×256 子块中移动,每个线程束根据其当前的 32×32 子块更新共享内存中的 32×32 max_vals。因此,随着每个子块的处理,max_vals[i,j] 会保存最新的最大值。
在完成 256×256 块的网格步长循环后,maxvals 矩阵本身会被归约,以找到整个 256 块的单个绝对最大值。
这为我们提供了该 2D 256×256 块的最终缩放因子。
阶段 2 – 使用阶段 1 找到的单个最大值缩放因子,将 256×256 块的值量化为 Float8。
接下来,我们对整个 256×256 块进行第二次遍历,使用阶段 1 找到的最大值重新缩放所有数字,将其转换为 float8 格式。
因为我们知道需要进行 2 次完整遍历,所以在阶段 1 的加载过程中,我们指示 Triton 编译器以更高的优先级将这些值保留在缓存中(驱逐策略 = 最近最少使用/Last)。
这意味着在第二次遍历期间,我们可以从 L2 缓存中获得很高的命中率,这比直接访问 HBM 提供快得多的内存访问速度。
当所有 256×256 块处理完毕,2D 块量化处理即告完成。我们返回新的 Float8 量化张量及其缩放因子矩阵,这些将在 GEMM 处理的下一阶段使用。此输入量化过程也会在第二个输入张量上重复,最终我们得到 A_Float8、A_scaling_matrix 以及 B_Float8 和 B_scaling_matrix。
GridQuant – GEMM 内核
GridQuant-GEMM 内核接收上述量化的四个输出进行处理。我们的高性能 GEMM 内核采用了多项新的 Triton 开发成果,旨在为 LLM 推理解码阶段相关的矩阵形状配置实现业界领先 (SOTA) 的性能。
这些新特性常见于 Hopper 优化内核(如 FlashAttention-3 和 Machete),并基于 CUTLASS 3.x 构建。在此,我们讨论这些方法,并展示在 Triton 中利用它们所能带来的性能收益。
张量内存加速器 (TMA)
NVIDIA Hopper GPU 上的 TMA 单元是一个专用的硬件单元,用于处理 AI 工作负载中常见的多维张量的加载/存储操作。这有几个重要的好处。
从全局内存到共享内存的数据传输可以在不占用 GPU SM 上其他资源的情况下进行,从而释放寄存器和 CUDA 核心。此外,在采用线程束专用化 (warp-specialized) 的内核中使用时,轻量级的 TMA 操作可以分配给生产者线程束,从而实现内存传输与计算的高度重叠。
有关如何在 Triton 中使用 TMA 的更多详细信息,请参阅我们的往期博客。
线程束专用化 (协作式持久内核设计)
线程束专用化是一种利用 GPU 流水线并行性的技术。该实验性功能通过 tl.async_task API 实现专用线程的表达,允许用户指定 Triton 程序中的操作应如何在线程束之间“分配”。协作式 Triton 内核执行不同类型的计算和加载,每种操作都在其各自的专用硬件上进行。为这些专用任务配备专用硬件,使得针对无数据依赖的操作实现高效并行成为可能。

图 3. NVIDIA H100 SM 中专用硬件单元的逻辑视图
我们内核中创建流水线的操作包括:
A – 将块缩放因子从 GMEM 加载到 SMEM (cp.async 引擎)
B – 将激活值 (A) 和权重 (B) 的瓦片从 GMEM 加载到 SMEM (TMA)
C – A 瓦片和 B 瓦片的矩阵乘法 = C 瓦片 (Tensor Core)
D – 使用来自 A 的块缩放因子和来自 B 的块缩放因子对 C 瓦片进行缩放 (CUDA 核心)
这些步骤可以分配给由线程块中的专用线程束组执行的“任务”。这种协作策略包含三个线程束组:一个负责向计算单元供数据的生产者线程束组,以及两个执行计算的消费者线程束组。两个消费者线程束组各自处理同一个输出瓦片的一半。

图 4. 线程束专用化的持久协作内核(来源:NVIDIA)
这与我们在前一篇博客中讨论的 Ping-Pong 调度不同,Ping-Pong 调度中每个消费者线程束组处理的是不同的输出瓦片。我们注意到,Tensor Core 操作与结尾处理 (epilogue) 计算并未重叠。与始终让 Tensor Core 保持忙碌的 Ping-Pong 相比,在计算的结尾阶段降低 Tensor Core 流水线的利用率,会减少消费者线程束组的寄存器压力,从而允许使用更大的瓦片尺寸。
最后,我们的内核设计为持久化内核,当网格大小超过 H100 GPU 上的可用计算单元数量 (132) 时使用。持久化内核在 GPU 上长时间保持活跃,并在其生命周期内计算多个输出瓦片。我们的内核利用 TMA 异步共享内存到全局内存的存储,同时继续处理下一个输出瓦片,而不是产生调度多个线程块的开销。
微基准测试

图 5:在小批量环境下,针对 Llama3 8192 N,K 规模,GridQuant-GEMM 与我们性能最好的 SplitK 内核的延迟 (us) 对比。(数值越低越好)
该线程束专用化的 Triton 内核在上述小 M 值和正方形矩阵形状下达到了 SOTA 性能,相比 Triton 之前的最佳策略(SplitK 内核,在低算术强度范畴内),实现了近 1.2 倍 的速度提升。对于后续工作,我们计划针对中等到大 M 范畴以及非正方形矩阵调优内核性能。
结论和未来工作
未来的工作包括在端到端工作流中对 GridQuant 进行基准测试。此外,我们计划对非正方形(矩形)矩阵以及中到大 M 尺寸进行更广泛的基准测试。最后,我们计划探索 Triton 中 Ping-Pong 式的线程束专用化与当前协作式实现之间的差异。