作者:Meta 的 Less Wright,IBM 的 Adnan Hoque

Float8 (FP8) 的二维块量化有望提高 Float8 量化的精度,同时加速用于推理和训练的 GEMM。在本博客中,我们将展示使用 Triton 在执行块量化 Float8 GEMM 的两个主要阶段所取得的进展。

对于将高精度 (BFloat16) 的 A 和 B 张量进行传入量化到 Float8,我们展示了 GridQuant,它利用迷你网格步长循环式处理方法,比当前二维块量化内核提速近 2 倍 (99.31%)。

对于 Float8 GEMM,我们展示了 Triton 的 3 项新发展——Warp 特化、TMA 和持久性内核,以有效创建协作式内核(Ping-Pong 调度的替代方案)。结果,我们比去年性能最佳的 SplitK 内核实现了约 1.2 倍 的加速。

Figure 1: A comparison of the 2D quantization speedup over a current baseline, across a range of sizes.

图 1: 在不同尺寸范围下,二维量化相对于当前基准的加速比较。(越低越好)

为什么对 FP8 进行二维块量化?

一般来说,随着我们从张量级缩放、行级缩放、二维块级缩放,最终到列级缩放,fp8 量化的精度会提高。这是因为给定 token 的特征存储在每一列中,因此该张量中的每一列缩放方式更相似。

为了最大限度地减少给定数值集中的异常值数量,我们希望找到共性,以便数字以类似的方式进行缩放。对于 transformer 模型,这意味着基于列的量化可能是最优的……然而,由于数据在内存中是按行连续布局的,列式内存访问效率极低。因此,列式加载将需要涉及内存中大步长访问才能提取孤立值,这与高效内存访问的核心原则相悖。

然而,二维块量化是次优选择,因为它包含列式量化的某些方面,同时由于我们可以使用二维向量化来向量化这些加载操作,因此更节省内存。因此,我们希望找到提高二维块量化速度的方法,这就是我们开发 GridQuant 内核的原因。

对于量化过程,我们需要对输入的更高精度 BF16 张量(A = 输入激活,B = 权重)进行二维块量化,然后使用量化后的张量及其二维块缩放值进行 Float8 矩阵乘法,并返回 BF16 格式的输出 C 张量。

GridQuant 如何提高二维块量化效率?

GridQuant 内核在最初基于标准分块实现的基线量化实现之上进行了多项改进。GridQuant 内核对整个输入张量进行两次完整遍历,其工作流程如下:

阶段 1 - 确定来自输入高精度张量的每个 256x256 子块的最大绝对值。

1 - 我们将 BF16 张量划分为 256 x 256 子块。此量化尺寸可配置,但默认为 256x256,因为它在量化精度和处理效率之间取得了平衡。

2 - 每个 256x256 子块被细分为 64 个子块,以 8x8 模式排列,每个子块处理一个 32x32 元素的块。一个 warp(32 个线程)负责处理其分配的 32x32 块内的所有元素。

3 - 我们在共享内存中声明一个 32x32 的 `max_vals` 数组。随着二维向量块在整个 256x256 子块中移动,这将存储每个位置 i,j 的当前最大值。

这是一项重要的改进,因为它意味着我们可以对最大值评分系统进行向量化而非标量更新,从而实现更高效的更新。

Figure 2: The Fractionalized layout of an incoming tensor - a grid of 256x256 is created across the tensor, and within each 256x256 block, it is further refined into 32x32 sub blocks. A 32x32 max_vals is created for each 256x256 block.

图 2: 输入张量的分块布局——在张量上创建一个 256x256 的网格,在每个 256x256 块内,进一步细化为 32x32 子块。为每个 256x256 块创建一个 32x32 的 `max_vals`。

4 - 每个 warp 处理一个 32x32 的块,并且由于我们使用了 4 个 warps,我们确保 Triton 编译器可以将下一个 32x32 块的内存加载与当前块的 absmax 计算处理进行流水线操作。这确保了 warp 调度器能够在加载数据的 warps 和进行处理的 warps 之间切换,并使 SM 持续保持忙碌。

5 - 32x32 的二维向量块处理以网格步长循环的方式在整个 256x256 子块中移动,每个 warp 根据其当前的 32x32 子块更新共享内存中的 32x32 `max_vals`。因此,随着每个子块的处理,`max_vals[i,j]` 保存最新的最大值。

完成 256x256 块的网格步长循环后,`max_vals` 矩阵本身会被规约,以找到整个 256 块的绝对单个最大值。

这为我们的二维 256 x 256 块提供了最终的缩放因子值。

阶段 2 - 使用阶段 1 中找到的单个最大值缩放因子,将 256x256 块的值量化为 Float8。

接下来,我们对整个 256x256 块进行第二次遍历,使用阶段 1 中找到的这个最大值重新缩放所有数字,将它们转换为 float 8 格式。

由于我们知道需要进行两次完整的遍历,对于阶段 1 部分的加载,我们指示 triton 编译器以更高的优先级(逐出策略 = last)将这些值保留在缓存中。

这意味着在第二次遍历期间,我们可以从 L2 缓存获得高命中率,这比一直访问 HBM 提供快得多。

当所有 256x256 块处理完成后,二维块量化处理就完成了。我们可以返回新的 Float8 量化张量及其缩放因子矩阵,这些将在 GEMM 处理的下一阶段使用。此输入量化也会对第二个输入张量重复进行,这意味着我们最终会得到 A_Float8、A_scaling_matrix 以及 B_Float8 和 B_scaling_matrix。

GridQuant - GEMM 内核

GridQuant-GEMM 内核接收上述量化的四个输出进行处理。我们的高性能 GEMM 内核具有多项新的 Triton 发展,可在 LLM 推理的解码阶段实现与矩阵形状相关的 SOTA 性能。

这些新功能常见于 Hopper 优化的内核中,例如使用 CUTLASS 3.x 构建的 FlashAttention-3Machete。在此,我们将讨论这些方法,并展示如何在 Triton 中利用它们实现性能优势。

张量内存加速器 (TMA)

NVIDIA Hopper GPU 上的 TMA 单元是专门用于对 AI 工作负载中常见的多维张量进行加载/存储操作的硬件单元。这带来了多项重要的好处。

从全局内存和共享内存传输数据可以在不涉及 GPU SM 上其他资源的情况下发生,从而释放寄存器和 CUDA 核心。此外,当用于 warp 特化内核时,可以将轻量级 TMA 操作分配给生产 warp,从而实现内存传输和计算的高度重叠。

有关 TMA 在 Triton 中如何使用的更多详细信息,请参阅我们的上一篇博客

Warp 特化(协作式持久性内核设计)

Warp 特化是一种利用 GPU 上的流水线并行性的技术。这个实验性功能通过 tl.async_task API 实现特化线程的表达,允许用户指定 Triton 程序中的操作应如何在 warps 之间“拆分”。协作式 Triton 内核执行不同类型的计算和加载操作,这些操作都在其各自的专用硬件上进行。为每个特化任务配备专用硬件,可以高效地实现无数据依赖性操作的并行性。

Figure 3. Logical view of dedicated HW units in NVIDIA H100 SM

图 3。 NVIDIA H100 SM 中专用硬件单元的逻辑视图

我们内核中创建流水线的操作包括:

A - 将每块缩放因子从 GMEM 加载到 SMEM (cp.async 引擎)

B - 将激活 (A) 和权重 (B) 块从 GMEM 加载到 SMEM (TMA)

C - A 块和 B 块的矩阵乘法 = C 块 (Tensor Core)

D - 使用来自 A 和 B 的每块缩放因子对 C 块进行缩放 (CUDA core)

这些步骤可以分配给“任务”,由线程块中特化的 warp 组执行。协作策略有三个 warp 组。一个负责向计算单元提供数据的生产者 warp 组,以及两个执行计算的消费者 warp 组。两个消费者 warp 组各自处理同一个输出块的一半。

Figure 4. Warp-Specialized Persistent Cooperative kernel

图 4。 Warp 特化持久性协作式内核(来源:NVIDIA

这与我们在上一篇博客中讨论的 ping-pong 调度不同,ping-pong 调度中每个消费者 warp 组处理不同的输出块。我们注意到 Tensor Core 操作与后续计算(epilogue computation)没有重叠。与始终保持 Tensor Core 忙碌的 ping-pong 调度相比,计算后续阶段 Tensor Core 流水线利用率的降低将减少消费者 warp 组的寄存器压力,从而允许更大的块尺寸。

最后,当网格尺寸超过 H100 GPU 上可用计算单元的数量 (132) 时,我们的内核被设计为持久性的。持久性内核在 GPU 上长时间保持活动状态,并在其生命周期内计算多个输出块。我们的内核利用 TMA 异步共享到全局内存存储,同时继续处理下一个输出块,而不是产生调度多个线程块的开销。

微基准测试

Figure 5: Latency comparison (us) of Gridquant-GEMM vs our best performing SplitK kernel for small batch regime and Llama3 8192 N,K sizing.

图 5: 在小批量模式和 Llama3 8192 N,K 尺寸下,Gridquant-GEMM 与我们性能最佳的 SplitK 内核的延迟比较 (微秒)。(越低越好)

Warp 特化的 Triton 内核在上述小 M 和方阵形状下实现了 SOTA 性能,比 SplitK Triton 内核提速近 1.2 倍,SplitK Triton 内核是 Triton 在此低算术强度情况下的先前最佳性能策略。未来的工作计划包括调整我们的内核性能,使其适用于中到大 M 情况和非方阵。

结论和未来工作

未来工作包括在端到端工作流程上对 GridQuant 进行基准测试。此外,我们计划对非方阵(矩形)以及中到大 M 尺寸进行更广泛的基准测试。最后,我们计划探索 Triton 中的 ping-pong 式 warp 特化与当前的协作实现之间的差异。