博客

一些矩阵乘法引擎的精度不如我们想象的那么高

加速器 GEMM 引擎中的累加器是什么?它为何如此重要?

GPU 和定制加速器包含用于矩阵乘法(也称为 matmul 或 GEMM)的专用计算引擎,例如 NVIDIA 的 Tensor Core。这些引擎能高效地对小型张量块进行矩阵乘法运算;因此,编译器或库通常会将大型矩阵乘法问题分解为许多较小的问题,并将其输入到这些引擎中。通常,由 FP8 (e4m3) 格式且形状为 (block_size_m, block_size_k) 和 (block_size_k, block_size_n) 的矩阵乘法计算出的 Tensor Core 输出,是一个 FP32 (e8m23) 格式的 (block_size_m, block_size_n) 张量。然而,一个用户很少注意到但十分有趣的现象是,出于硬件效率的考虑,这个 FP32 输出的有效尾数位(mantissa bits)可能少于 23 位。换句话说,该 Tensor Core 操作的精度表面上是 FP32,实则更低。据报道,这种硬件设计选择在特定情况下会影响模型精度 1, 2。因此,从 GPU 用户的角度来看,我们希望验证正在使用的硬件设计。因为即使无法改变现有硬件,在需要时仍然可以通过正确编写自定义内核来保持所能达到的最高精度。对于硬件设计者而言,有一种方便且高效的方法来量化这种设计选择的影响同样至关重要。

在深入探讨细节之前,我们需要理解“累加器”的作用以及采用降低精度的原因。首先,让我们考虑一个假设的计算引擎,它可以处理块大小为 (3, 4) 和 (4, 3) 的 FP8 矩阵乘法,如图 1a 所示。放大来看,该计算引擎最基本的操作是行-列内积,即:

cᵢⱼ = ∑ₖ aᵢₖ * bₖⱼ。可以想象,一种高效的硬件设计将简单地实现 4 个乘法器来计算每一对 aik, bkj,随后使用 3 个加法器将中间结果相加,如图 1b 所示。在这个简单的例子中,我们可以看到,如果乘法器数量充足,乘法部分可以在一个并行化的“计算步骤”中完成。但加法部分需要 2 个计算步骤才能完成,因为它必须以分层、串行的方式进行。如果我们为 N 个元素扩展这种单元设计,乘法仍只需一步,而加法将需要 log(N) 步。

此外,每个乘法器只需要计算 FP8 * FP8 (e4m3),这涉及 4 位 + 4 位的加法(用于指数)和 4 位 x 4 位的乘法(用于尾数)。然而,由于每个部分积需要正确对齐,后续的加法器必须使用比乘法器多得多的位数。如图 2 所示(仅作示例,并非真正的 FP8 情况),将两个仅有 4 位尾数的有限精度浮点数相加,最终得到的浮点数可能需要更多的尾数位。这从侧面解释了为什么浮点乘加 (MAC) 操作的电路复杂度和成本(硅片面积和功耗)在很大程度上取决于累加精度。因此,即使使用 FP32 作为累加精度更安全(图 2b),探索使用降低累加精度的机会也是值得的。

考虑到这些例子,在矩阵乘法引擎中使用降低精度的加法器的优势就显而易见了。

如何验证累加器精度?(以 TensorCore 为例)

鉴于矩阵乘法累加器的设计可能少于 23 位尾数,实际输出实际上是 e8mNacc (其中 Nacc < 23),并将尾部的 0 填充至 e8m23。换句话说,FP8 TensorCore 的输出看起来可能像是 FP32,但在计算过程中,小于 e8mNacc 的部分从未被计算过。在本篇博客中,我们将演示一种使用 Triton 内核研究累加器精度的简单方法。

假设 TensorCore 输出只有 Nacc 个有效尾数位(如 e8mNacc 中所示),即最后 23 − Nacc 位已经是 0。如果我们应用掩码来截断 TensorCore 输出的最后 Ntrun 位,只要 Ntrun ≤ 23 − Nacc,最终的矩阵乘法结果应该保持不变。此外,通过扫描 Ntrun 并将矩阵乘法输出与参考值(即 Ntrun = 0)进行比较,我们可以推断出正在研究的 FP 矩阵乘法单元的累加器精度。此处,“截断 Ntrun 位”是指将浮点数最后 Ntrun 位(即尾数的最低有效位 LSB)清零。

为什么选择 Triton?

我们使用 triton 语言是因为它允许将所提出的方法推广到其他支持 Triton 的加速器上。由于其简单性以及对加速器提供了恰到好处的控制能力,它也极大地加快了此实验的开发速度。尽管 Triton 预计会随着时间不断演进,但由于我们的实现基于 Triton 的矩阵乘法教程,我们预计未来需要重写的特定代码将非常少。

实验

本笔记本最后提供了可运行的代码。在此,我们采用了来自 Triton 教程 的 Triton 矩阵乘法内核,并添加了一个简单的截断函数。由于原始教程中有大量细节,我们仅重点介绍所做的与截断相关的修改。粗略地说,matmul(A, B) 被分解为较小的块并并行处理。A 和 B 的每个块的形状分别为 (BLOCK_SIZE_M, BLOCK_SIZE_K)(BLOCK_SIZE_K, BLOCK_SIZE_N)。块级矩阵乘法由 Triton 的 tl.dot() 函数计算,产生一个形状为 (BLOCK_SIZE_M, BLOCK_SIZE_N) 的临时张量 accumulator_inner,该张量被假定只有 Nacc 个有效尾数位。

  1. accumulator_inner 的截断:我们使用带有预定义掩码的位操作截断了 accumulator_inner 的最后 Ntrun 位。为简单起见,我们将 round_bit 设置为 0,从而忽略舍入。
def prep_round_and_trun_mask(trun_bits):
        round_bit = 1 << (trun_bits - 1) if trun_bits > 0 else 0
        trun_mask = ~tl.cast((1 << trun_bits) - 1, tl.uint32)
        return round_bit, trun_mask
def round_and_trun(x, round_bit, trun_mask):
        """Round and truncate (usually for accumulator)."""
        return libdevice.uint_as_float(
            (libdevice.float_as_uint(x) + round_bit) & trun_mask
 )

2. 跨 K 维度的累加:在遍历 K 维度时,每个截断后的 accumulator_inner 会被进一步累加到预先分配的 FP32 张量累加器中。accumulator 的形状与 accumulator_inner 相同。

3. 结果回写:在遍历完 K 维度后,最终的 accumulator 值会被写回目标输出张量 C 的相应块中,张量 C 的形状为 (M, N)

结果与讨论

从下方的表 1 和图 3 可以观察到,截断(使用 H100 FP8 TensorCore)输出的最后 10 位有效尾数所得的结果,与不进行截断的情况完全相同。这表明这些位在原始输出中已经是 0。因此,该实验表明,出于计算效率的考虑,累加器采用了特殊的 FP22 格式 (e8m13) 实现。我们在 RTX4000 系列 GPU(Ada Lovelace 架构)上重复了同样的实验,观察到了相同的行为。

我们应该记住的一个重要考量是,该实验依赖于 Triton 编译器将 Triton 代码翻译成等效的 CUDA 代码。我们必须确保执行任务的 TensorCore 确实是我们打算检查的那个,即 FP8。在极少数情况下,Triton 编译器可能会选择对某些 FP8 计算使用 FP16 TensorCore 指令。确认实际执行的硬件指令最可靠的方法是使用 NVIDIA 分析器 ncu(3,包含在 cudatoolkit 中),以检查与 Triton tl.dot 调用相关联的底层 CUDA 指令。

读者可以将此笔记本保存为 python 文件,然后使用以下命令行调用来启动 ncu

/usr/local/cuda-13.0/bin/ncu --target-processes all --set full 
--import-source yes -f --kernel-name matmul_kernel --launch-skip 3 
--launch-count 1 -o ./tl_fp8mm_backend_H100 python 
accumulator_precision_test.py

从下方显示的 ncu 分析器读取结果中,我们发现所选块大小(MxNxK=64x64x32)的 FP8xFP8 tl.dot() 被翻译为一条 QGMMA 指令——这是一种 FP8-TensorCore 专有的指令。这证实了确实使用了 FP8 TensorCore。

如前所述,Triton 编译器有时会为 tl.dot 选择不同的实现方式。例如,如果我们设置 num_warps = 2 并重复该实验,Triton 将把 FP8 封装为 FP16 并使用 HMMA 来执行计算,其中 HMMA 是 FP16-TensorCore 专有的指令。在这种情况下,相应的结果显示 FP16 TensorCore 的累加器仅比 FP32 短 1 位。

此外,由于专用矩阵乘法单元旨在处理特定固定大小的输入,如果我们选择的 BLOCK_SIZE 超过了 TensorCore 的处理能力,编译器或 CUDA 库会自动将该操作分解为多个较小的操作。在我们的 Triton 代码中,我们可以将 BLOCK_SIZE K 增加到 128 并再次使用 ncu 进行验证。我们将看到每条 WGMMA 指令只能处理 K=32,这意味着还需要额外的求和来组合来自多个 TensorCore 调用的部分结果。一个自然的问题是:这种中间求和使用了什么精度?这就是我们一直在讨论的 FP 对齐和精度损失问题。基于 K=128 实验的输出,我们仍然观察到 13 位有效尾数。这提供了一个重要的见解:如果我们为 Triton 内核选择的块大小超过了 TensorCore 的基础设计(无论是出于性能原因还是自动调优),可能会因为低精度求和而导致额外的精度损失。因此,如果矩阵乘法精度是一个关键问题(尤其是在涉及训练和反向传播时),在回退到 FP16 之前,我们应该首先尝试像我们在 Triton 代码中所做的那样使用中间 FP32 累加。我们在此演示了 BLOCK_SIZE_K 对精度的影响,但读者应记住,较小的块会影响内核性能。读者可能希望从较大的块大小开始,例如如果自动调优建议 256 或 512,则逐渐减小到 128(如文献 1 所述),并考虑使用 FP16 与减小块大小之间的权衡。或者,如果在自定义内核中使用 cuBLAS,CUBLASLT_MATMUL_DESC_FAST_ACCUM 标志可以达到同样的提升累加精度的效果。4

最后,低精度累加器的概念也可以应用于 INT8xINT8 引擎。FP8 和 INT8 矩阵乘法的主要区别在于,INT8 累加器截断发生在最高有效位 (MSB) 而不是最低有效位 (LSB)。换句话说,我们需要考虑溢出问题,而不是像 FP8 那样考虑下溢。可以对提供的 Triton 内核进行简单的修改来研究 INT8 的行为,我们将这个练习留给感兴趣的读者。

结论

我们解释了在矩阵乘法引擎的累加器中使用降低精度的重要性,并演示了一种验证现有加速器设计的简单方法。对于编写自定义内核且对精度敏感的应用的用户,以及需要为下一代设计模拟这种行为的硬件设计者来说,理解累加器精度至关重要。更重要的是,这种基于 Triton 内核的方法可以与 PyTorch 生态系统无缝结合,这意味着该技术可以扩展到其他支持 Triton 语言的现有和未来加速器,从而显著减少开发时间。

参考

  1. DeepSeek-V3 技术报告,第 3.3.2 节:提高累加精度。 https://arxiv.org/html/2412.19437v1.
  2. SageAttention2,简介/挑战/C2。 https://arxiv.org/html/2411.10958v7
  3. ncu 网站 https://docs.nvda.net.cn/nsight-compute/index.html
  4. https://docs.nvda.net.cn/cuda/cublas/

可运行代码可在此处找到

https://gist.github.com/chichun-charlie-liu/88a99949fcbe589aa5f71e48616ac944