跳转到主要内容
博客

在 Triton 中加速 2D 动态块量化 Float8 GEMM

作者: 2024 年 12 月 6 日2025 年 5 月 5 日暂无评论

Float8 (FP8) 的 2D 块量化有望提高 Float8 量化的准确性,同时加速推理和训练中的 GEMM。在本博客中,我们将展示如何使用 Triton 在块量化 Float8 GEMM 所涉及的两个主要阶段中取得进展。

对于将 A 和 B 张量从高精度 (BFloat16) 量化为 Float8,我们展示了 GridQuant,它利用微型网格步长循环处理方式,比当前的 2D 块量化内核实现了近 2 倍的加速 (99.31%)。

对于 Float8 GEMM,我们展示了 Triton 的三项新开发——Warp 专用化、TMA 和持久化内核,以有效地创建协作式内核(替代 Ping-Pong 调度)。因此,我们比去年表现最佳的 SplitK 内核实现了约 1.2 倍的加速。

Figure 1: A comparison of the 2D quantization speedup over a current baseline, across a range of sizes.

图 1: 2D 量化在不同尺寸范围内的加速与当前基线的比较。(越低越好)

为什么选择 FP8 的 2D 块量化?

一般来说,FP8 量化的精度随着我们从张量级缩放、到行级缩放、到 2D 块级,再到最终的列级缩放而提高。这是因为给定令牌的特征存储在每一列中,因此该张量中的每一列都具有更相似的缩放比例。

为了最小化给定数值集中的离群值数量,我们希望找到共性,以便数值以相似的方式进行缩放。对于 Transformer 来说,这意味着基于列的量化可能是最优的……然而,由于数据在内存中以行式连续的方式布局,列式内存访问效率极低。因此,列式加载将需要涉及内存中的大步长内存访问,以提取孤立值,这与高效内存访问的核心原则相悖。

然而,2D 是次优选择,因为它包含列式的一些方面,同时由于我们可以通过 2D 向量化来向量化这些加载,因此内存效率更高。因此,我们希望找到提高 2D 块量化速度的方法,这就是我们开发 GridQuant 内核的原因。

对于量化过程,我们需要对传入的高精度 BF16 张量(A = 输入激活,B = 权重)进行 2D 块量化,然后使用量化张量及其 2D 块缩放值进行 Float8 矩阵乘法,并返回 BF16 格式的输出 C 张量。

GridQuant 如何提高 2D 块量化效率?

GridQuant 内核比最初的基线量化实现(标准的基于瓦片的实现)有几项改进。GridQuant 内核对整个输入张量进行了两次完整遍历,其工作方式如下:

阶段 1 – 从传入的高精度张量中确定每个 256x256 子块的最大绝对值。

1 – 我们将 BF16 张量划分为 256 x 256 子块。这个量化大小是可配置的,但 256x256 是默认值,因为它在量化精度和处理效率之间提供了平衡。

2 – 每个 256x256 子块进一步细分为 64 个子块,以 8x8 模式排列,每个子块处理一个 32x32 元素块。单个 Warp(32 个线程)处理其指定 32x32 块内所有元素的计算。

3 – 我们在共享内存中声明一个 32x32 的 max_vals 数组。当 2D 向量块在整个 256x256 子块中移动时,它将存储每个位置 i,j 的当前最大值。

这是一个重要的改进,因为它意味着我们可以对最大值评分系统进行向量化(而不是标量)更新,并允许更高效的更新。

Figure 2: The Fractionalized layout of an incoming tensor - a grid of 256x256 is created across the tensor, and within each 256x256 block, it is further refined into 32x32 sub blocks. A 32x32 max_vals is created for each 256x256 block.

图 2: 传入张量的分块布局——在张量上创建 256x256 的网格,在每个 256x256 块内,进一步细化为 32x32 子块。为每个 256x256 块创建 32x32 的 max_vals。

4 – 每个 Warp 处理一个 32x32 块,并且由于我们使用了 4 个 Warp,我们确保 Triton 编译器可以对下一个 32x32 块的内存加载与当前块的 absmax 计算的实际处理进行流水线化。这确保了 Warp 调度器能够切换加载数据的 Warp 与正在处理数据的 Warp,并保持 SM 持续忙碌。

5 – 32x32 2D 向量块处理以网格步进循环的方式在整个 256x256 子块中移动和穿过,每个 Warp 根据其当前的 32x32 子块更新共享内存 32x32 max_vals。因此,max_vals[i,j] 存储了每个子块处理后的最新最大值。

完成 256x256 块网格步进循环后,max_vals 矩阵本身也会被归约,以找到整个 256 块的绝对单最大值。

这为我们提供了这个 2D 256 x 256 块的最终缩放因子值。

阶段 2 – 通过使用在阶段 1 中找到的单一最大值缩放因子,将 256x256 块值量化为 Float8。

接下来,我们对整个 256x256 块进行第二次遍历,使用在阶段 1 中找到的最大值重新缩放所有数字,以将其转换为 float 8 格式。

因为我们知道需要进行 2 次完整的遍历,所以在阶段 1 部分的加载过程中,我们指示 Triton 编译器以更高的优先级(逐出策略 = last)将这些值保留在缓存中。

这意味着在第二次遍历期间,我们可以从 L2 缓存获得较高的命中率,这提供了比一直访问 HBM 更快的内存访问速度。

当所有 256x256 块都处理完毕,2D 块量化处理完成时,我们可以返回新的 Float8 量化张量及其缩放因子矩阵,我们将在 GEMM 处理的下一阶段使用它们。这种输入量化也对第二个输入张量重复进行,这意味着我们最终得到 A_Float8、A_scaling_matrix 以及 B_Float8 和 B_scaling_matrix。

GridQuant – GEMM 内核

GridQuant-GEMM 内核接收来自上述量化的四个输出进行处理。我们的高性能 GEMM 内核具有几项新的 Triton 开发,可在 LLM 推理解码阶段相关的矩阵形状配置文件中实现 SOTA 性能。

这些新功能常见于 Hopper 优化内核,如 FlashAttention-3Machete,它们使用 CUTLASS 3.x 构建。在这里,我们讨论这些方法并展示在 Triton 中利用它们可以实现的性能优势。

张量内存加速器 (TMA)

NVIDIA Hopper GPU 上的 TMA 单元是一个专用于加载/存储操作的硬件单元,它作用于 AI 工作负载中常见的多维张量。这有几个重要的好处。

数据从全局内存和共享内存传输可以在不涉及 GPU SM 上其他资源的情况下发生,从而释放寄存器和 CUDA 核心。此外,当在 Warp 专用内核中使用时,轻量级 TMA 操作可以分配给生产者 Warp,从而实现内存传输和计算的高度重叠。

有关 TMA 如何在 Triton 中使用的更多详细信息,请参阅我们的上一篇博客

Warp 专用化(协作式持久化内核设计)

Warp 专用化是一种利用 GPU 上流水线并行性的技术。这项实验性功能通过 tl.async_task API 实现了专用线程的表达,允许用户指定 Triton 程序中的操作应如何在 Warp 之间“分割”。协作式 Triton 内核执行不同类型的计算和加载,每种都在其自己的专用硬件上进行。为每个专用任务配备专用硬件使得在没有数据依赖性的操作中高效实现并行成为可能。

Figure 3. Logical view of dedicated HW units in NVIDIA H100 SM

图 3. NVIDIA H100 SM 中专用硬件单元的逻辑视图

我们内核中创建流水线操作是:

A – 将每块缩放从 GMEM 加载到 SMEM (cp.async engine)

B – 将激活 (A) 和权重 (B) 瓦片从 GMEM 加载到 SMEM (TMA)

C – A 瓦片和 B 瓦片的矩阵乘法 = C 瓦片 (Tensor Core)

D – 用 A 的每块缩放和 B 的每块缩放缩放 C 瓦片 (CUDA Core)

这些步骤可以分配给由线程块中专用 Warp 组执行的“任务”。协作策略有三个 Warp 组。一个生产者 Warp 组负责为计算单元提供数据,两个消费者 Warp 组执行计算。这两个消费者 Warp 组各自处理同一输出瓦片的一半。

Figure 4. Warp-Specialized Persistent Cooperative kernel

图 4. Warp 专用持久化协作内核(来源:NVIDIA

这与我们之前的博客中讨论的 ping-pong 调度不同,在 ping-pong 调度中,每个消费者 Warp 组处理不同的输出瓦片。我们注意到 Tensor Core 操作与结尾计算没有重叠。与 ping-pong 始终保持 Tensor Core 忙碌相比,在计算的结尾阶段降低 Tensor Core 流水线利用率将减少消费者 Warp 组的寄存器压力,从而允许更大的瓦片大小。

最后,当网格大小超过 H100 GPU 上可用计算单元的数量(132)时,我们的内核被设计为持久化。持久化内核在 GPU 上长时间保持活动状态,并在其生命周期内计算多个输出瓦片。我们的内核利用 TMA 异步共享到全局内存存储,同时继续处理下一个输出瓦片,而不是承担调度多个线程块的成本。

微基准测试

Figure 5: Latency comparison (us) of Gridquant-GEMM vs our best performing SplitK kernel for small batch regime and Llama3 8192 N,K sizing.

图 5: Gridquant-GEMM 与我们性能最佳的 SplitK 内核在小批次和 Llama3 8192 N,K 尺寸下的延迟比较 (us)。(越低越好)

Warp 专用 Triton 内核在上述小 M 和方形矩阵形状上实现了 SOTA 性能,比 SplitK Triton 内核(在该低算术强度体系中 Triton GEMM 之前表现最佳的策略)实现了近 1.2 倍的加速。对于未来的工作,我们计划调整内核性能,以适应中大型 M 体系和非方形矩阵。

结论和未来工作

未来的工作包括对 GridQuant 进行端到端工作流的基准测试。此外,我们计划对非方形(矩形)矩阵以及中大型 M 尺寸进行更广泛的基准测试。最后,我们计划探索 Triton 中 ping-pong 风格的 Warp 专用化与当前的协作实现。