作者:IBM 和 Meta

IBM:Krish Agarwal、Rishi Astra、Adnan Hoque、Mudhakar Srivatsa、Raghu Ganti
Meta:Less Wright、Sijia Chen

量化是一种通过压缩模型权重和执行较低精度数据类型计算(速度更快)来提高模型推理速度的方法。然而,由于离群值的存在,量化可能导致精度损失。最近的工作,如 QuaRotSpinQuantFlashAttention-3,引入了增加 LLM 中 INT4、INT8 和 FP8 量化数值精度的方法。这些方法依赖于哈达玛变换。在本博客中,我们介绍 HadaCore,这是一款在 NVIDIA A100 和 H100 GPU 上实现业界领先性能的哈达玛变换 CUDA 核。与 Dao AI Lab 的快速哈达玛变换核相比,我们的核实现了 1.1–1.4 倍(在 A100 上)和 1.0–1.3 倍(在 H100 上)的加速,峰值增益分别为 3.5 倍和 3.6 倍。我们利用了硬件感知的工作分解,该分解受益于 Tensor Core 加速,同时保持量化误差减少。

Figure 1: Speedup of HadaCore vs Dao AI Hadamard CUDA kernel. A peak gain of 3.46x on the A100 is achieved using 128 rotation by 8.4M elements.

图 1:HadaCore 与 Dao AI 哈达玛 CUDA 核的加速比。在 A100 上使用 128 大小的旋转操作,对 840 万个元素实现 3.46 倍的峰值增益。

HadaCore 核已公开可用

背景

QuaRotSpinQuant 都提出了提高 LLM 中 INT4 和 INT8 量化数值精度的方法。这两种方法都旋转模型激活值,因为旋转在统计学上可能减少离群值的幅值,因为它将极端值“分散”到其他(不那么极端的)维度中,并且旋转也是一种易于使用旋转矩阵逆运算的可逆操作。这些方法还可以提高 FP8 推理精度,例如在 FlashAttention-3 中。

Figure 2. Transformer block showing online (red) and offline rotations (blue) in QuaRot

图 2. Transformer 块显示 QuaRot 中的在线(红色)和离线旋转(蓝色)

应用这些旋转矩阵会引入模型运行时开销,如图 2 所示的在线操作。这些旋转可以通过矩阵乘法应用,但增加的开销会削弱量化的好处。因此,QuaRot 和 SpinQuant 选择使用 Walsh-Hadamard 矩阵,这是一种特殊的旋转矩阵,可以使用快速 Walsh-Hadamard 变换算法比矩阵乘法更快地应用。HadaCore 是该算法针对支持 Tensor Core 的 NVIDIA GPU 的优化实现。

Tensor Core 加速的哈达玛变换

HadaCore 利用 NVIDIA Tensor Core,这是 NVIDIA GPU 上针对矩阵乘法优化的专用计算单元。为了实现这一点,我们的核对快速 Walsh-Hadamard 算法进行了硬件感知的工作分解。这种工作分解确保我们可以利用在 Tensor Core 芯片上执行的 MMA PTX 指令。HadaCore 对输入数据的块应用 16x16 哈达玛变换。然后可以使用 mma.m16n8k16 指令将计算卸载到 FP16 Tensor Core。下面显示了 HadaCore 的 warp 级并行。

Figure 3: HadaCore Parallelization, 1x256 vectors (rows) being rotated by a size 256 Hadamard.

图 3:HadaCore 并行化,1x256 向量(行)由大小为 256 的哈达玛变换旋转。

我们使用 warp 级 Tensor Core 操作并行处理 256 个元素的片段,以实现高达 256 大小的哈达玛变换。对于更大尺寸,我们在 warp 之间混洗数据并重复。

微基准测试

我们在 NVIDIA H100 和 A100 GPU 上对 HadaCore 与 Dao AI Lab 哈达玛核进行了基准测试,测试了不同的哈达玛和输入张量大小。

Figure 4:  HadaCore Kernel Speedup on NVIDIA A100 over Dao AI Lab Fast Hadamard Kernel

图 4:HadaCore 核在 NVIDIA A100 上相对于 Dao AI Lab 快速哈达玛核的加速比

Color coded Speedup Table for NVIDIA A100, Green = Speedup over Baseline

NVIDIA A100 的颜色编码加速表,绿色 = 相对于基线的加速

Figure 5:  HadaCore Kernel Speedup on NVIDIA H100 over Dao AI Lab Fast Hadamard Kernel

图 5:HadaCore 核在 NVIDIA H100 上相对于 Dao AI Lab 快速哈达玛核的加速比

Color coded Speedup Table for NVIDIA H100, Green = Speedup over Baseline

NVIDIA H100 的颜色编码加速表,绿色 = 相对于基线的加速

我们的图表显示了随着输入张量大小(标记为元素数量)增加的加速比。元素数量是我们要旋转的目标矩阵中的元素数量。例如,在多头注意力中

查询 (Q)、键 (K) 和值 (V) 张量是大小为

(batch_size, seq_len, n_heads, head_dim) 的 4D 张量

大小为 head_dim 的哈达玛矩阵应用于这些激活张量,因此我们称之为使用大小为 head_dim 的哈达玛变换,元素数量为

batch_size*seq_len*n_heads*head_dim。

注意力块中查询旋转的常见元素数量

模型 \ Token Prefill Decoding
Llama-2 70b 33,554,432 个元素
128 哈达玛大小
(1 batch * 64 heads * 4096 tokens * 每个头每个 token 的 128 维嵌入)
8192 个元素
128 哈达玛大小
(1 batch * 64 heads * 1 token * 每个头每个 token 的 128 维嵌入)
Llama-3 8b 33,554,432 个元素
128 哈达玛大小
(1 batch * 32 heads * 8192 tokens * 每个头每个 token 的 128 维嵌入)
4,096 个元素
128 哈达玛大小
(1 batch * 32 heads * 1 token * 每个头每个 token 的 128 维嵌入)

HadaCore 在 A100 上相对于 Dao AI Lab 的快速哈达玛核实现了 1.1–1.4 倍的加速,在 H100 上实现了 1.0–1.3 倍的加速,峰值增益分别为 3.5 倍和 3.6 倍。对于 H100 上的较小尺寸,HadaCore 的增益有所下降。对于未来工作,我们计划结合使用 Hopper 特有功能(如 TMA 和 WGMMA)来提高 H100 的性能。

MMLU 基准测试

我们在 FlashAttention 计算以 FP8 格式执行的 Llama 3.1-8B 推理工作负载上评估了 MMLU 分数。更新一代的 NVIDIA Hopper GPU 配备了 FP8 Tensor Core,与 FP16 相比,提供了可观的计算增益。

我们的结果表明,当与 FP8 FlashAttention 等优化结合使用时,HadaCore 在保持精度方面的益处。

格式 方法 Llama3.1-8B
平均 5-Shot MMLU 精度
Q, K, V: FP16
FlashAttention: FP16
N/A 65.38
Q, K, V: FP16
FlashAttention: FP8
无哈达玛 64.40
Q, K, V: FP8
FlashAttention: FP8
HadaCore 65.09
Q, K, V: FP8
FlashAttention: FP8
Dao AI 快速哈达玛核 65.45

表 1:Llama3.1 8B 在 FP16 基线和使用哈达玛变换的 FP8 注意力下的 MMLU 分数,比较了显式哈达玛矩阵乘法实现与 HadaCore(越高越好

从上述 MMLU 分数可以看出,对于使用 FP8 注意力的 Llama3.1-8B 推理,HadaCore 改进了在较低精度下计算注意力引入的量化误差。

结论

我们展示了通过将快速 Walsh-Hadamard 算法移至利用 Tensor Core 加速的 CUDA 核所实现的加速,该核在 NVIDIA A100 和 H100 上相对于 Dao AI 快速哈达玛核分别实现了 3.5 倍和 3.6 倍的峰值加速。

此外,我们在 MMLU 基准测试中表明,使用 HadaCore 进行旋转与快速哈达玛核保持了相似的量化误差减少,同时提供了计算加速。

未来工作

我们计划实现我们核的 Triton 版本,并尝试更高级的技术,例如核融合,以支持融合哈达玛变换和量化。此外,我们计划扩展我们的核以支持 BF16 Tensor Core 计算。