博客

使用 Triton Kernels 加速 Llama3 FP8 推理

作者: 2024年5月1日2024年11月13日暂无评论

1.0 摘要

我们推出了一个名为 TK-GEMM 的优化 Triton FP8 GEMM(通用矩阵乘法)内核,它利用了 SplitK 并行化技术。针对小批量(Small batch size)推理场景,在 NVIDIA H100 GPU 上进行 Llama3-70B 推理任务时,TK-GEMM 相比基础 Triton matmul 实现速度提升高达 1.94 倍,相比 cuBLAS FP8 提升 1.87 倍,相比 cuBLAS FP16 提升 1.71 倍

TK-GEMM Speedup over PyTorch (calling cuBLAS) for Llama3-70B Attention Layer Matrix Shapes (N=K=8192)

图 1. TK-GEMM 针对 Llama3-70B 注意力层矩阵形状(N=K=8192)相较于 PyTorch(调用 cuBLAS)的速度提升

在本博客中,我们将介绍如何使用 Triton 设计针对 FP8 推理的优化内核,并将其针对 Llama3-70B 推理进行调优。我们将涵盖 Hopper 架构 GPU (SM90) 支持的新数据类型 FP8(8 位浮点数)、Triton 支持的关键 SM90 特性,以及我们如何修改并行化策略,以便在内存受限(推理)的问题规模下最大化内存吞吐量。

我们还专门开辟了一个章节介绍 CUDA 图(CUDA Graphs),这是一项重要的技术,有助于实现内核级的性能提升,并使希望在生产环境中使用 Triton 内核的开发者能够获得额外的性能增益。

代码仓库地址:https://github.com/pytorch-labs/applied-ai

2.0 FP8 数据类型

FP8 数据类型是由 Nvidia、Arm 和 Intel 联合推出的,旨在作为 16 位浮点类型的继任者。由于位宽仅为后者的一半,它在 Transformer 网络中相比前代产品具有显著提高吞吐量的潜力。FP8 数据类型包含两种格式:

E4M3(4 位指数和 3 位尾数)。能够存储 +/- 448 和 NaN。
E5M2(5 位指数和 2 位尾数)。能够存储 +/- 57,334、NaN 和 Inf。

BF16, FP16, FP8 E4M3 and FP8 E5M2

上图:BF16、FP16、FP8 E4M3 和 FP8 E5M2。
为了展示精度差异,下表展示了每种格式中最接近 0.3952 的表示。
图片来源:Nvidia

我们由于 E4M3 精度更高,在推理和训练前向传播中使用它;而由于 E5M2 动态范围更大,在训练反向传播中使用它。Nvidia 设计的 H100 FP8 Tensor Core 可提供 3958 TFLOPS 的峰值算力,是 FP16 Tensor Core FLOPS 的 2 倍

我们在设计 Triton 内核时充分考虑了这些硬件创新。在博客的后续部分,我们将讨论如何利用这些特性,并验证 Triton 编译器确实在使用它们。

3.0 Triton 对 Hopper 的支持与 FP8 Tensor Core 指令

Hopper GPU 架构增加了以下新特性,预计将加速 FP8 GEMM:

  • TMA(张量内存加速器)硬件单元
  • WGMMA(Warp 组矩阵乘加指令)
  • 线程块集群(Threadblock Clusters)

Triton 目前利用了其中的一项特性:wgmma 指令,而 PyTorch(调用 cuBLAS)则利用了全部三项特性,这使得 PyTorch 的性能提升更令人印象深刻。为了充分利用 Hopper FP8 Tensor Core,尽管旧的 mma.sync 指令仍受支持,但 wgmma 是必不可少的。

mma 和 wgmma 指令之间的关键区别在于:不是由 1 个 CUDA Warp 负责一个输出分片(shard),而是由整个 Warp 组(4 个 CUDA Warp)异步地为输出分片做出贡献。

为了观察该指令在实践中的表现,并验证我们的 Triton 内核确实使用了这一特性,我们使用 Nsight Compute 对 PTX 和 SASS 汇编代码进行了分析。

PTX Assembly

图 2. PTX 汇编

该指令在 SASS 中被进一步降级为 QGMMA 指令。

SASS Assembly

图 3. SASS 汇编

这两条指令都表明我们正在将两个 FP8 E4M3 输入张量相乘,并以 F32 格式进行累加,这证实了 TK-GEMM 内核确实在使用 FP8 Tensor Core,并且降级过程是正确的。

4.0 SplitK 工作分解

TK-GEMM vs Base Triton GEMM TFLOPS for M = 1-64

图 4. M = 1-64 时 TK-GEMM 与基础 Triton GEMM 的 TFLOPS 对比

基础的 Triton FP8 GEMM 实现对于小 M 情况(即矩阵乘法 A(MxN) x B(NxK) 中 M < N, K)表现不佳。为了优化这种矩阵配置,我们应用了 SplitK 工作分解,而不是基础 Triton 内核中采用的数据并行分解。这极大地改善了小 M 情况下的延迟。

作为背景知识,SplitK 沿着 K 维度启动额外的线程块来计算部分输出和。每个线程块的部分结果随后通过原子规约(atomic reduction)求和。这允许更细粒度的工作分解,从而带来性能提升。有关 SplitK 的更多详细信息,请参阅我们的 arXiv 论文

在仔细针对 Llama3-70B 问题规模调优了我们内核的其他相关超参数(如切片大小、Warp 数量和流水线阶段数)后,我们相比 Triton 基础实现实现了最高 1.94 倍的速度提升。有关超参数调优的更全面介绍,请参阅我们的博客

NCU profiler times for TK-GEMM under varying batch sizes, and compared with PyTorch (calling cuBLAS) FP8 and FP16.

上图:在不同 Batch Size 下 TK-GEMM 的 NCU 分析器耗时,并与 PyTorch(调用 cuBLAS)FP8 和 FP16 进行了比较。

注意,从 M=32 开始,cuBLAS FP8 内核开始超越 TK-GEMM。对于 M >= 32,我们推测当前找到的超参数并非最优,因此需要进行另一组实验来确定针对中等规模 M 的最优参数。

5.0 使用 CUDA 图实现端到端加速

为了在端到端设置中实现这些速度提升,我们必须同时考虑内核执行时间(GPU 时长)和总耗时(CPU+GPU 时长)。众所周知,手工编写的 Triton 内核(相对于 torch compile 生成的内核)存在较高的内核启动延迟。如果我们使用 torch profiler 跟踪 TK-GEMM 内核,我们可以查看到 CPU 端的调用堆栈,从而精确定位导致减速的原因。

CPU Launch Overhead: 2.413ms

图 5. CPU 启动开销:2.413ms

从上图可以看出,我们优化内核的大部分总耗时被 JIT(即时编译)开销占据。为了解决这个问题,我们可以使用 CUDA 图。

CUDA Graphs Visualization

图 6. CUDA 图可视化
图片来源:PyTorch

核心思想是:与其进行多次内核启动,不如创建并实例化一个图(一次性成本),然后提交该图的实例进行执行。为了说明这一点,我们模拟了一个 Llama3-70B 注意力层。如下图所示(使用 Nsight Systems 生成),由于 CPU 内核启动开销,每次 GEMM 之间的间隔为 165us,而实际矩阵乘法仅花费 12us。这意味着在注意力层中,GPU 有 92% 的时间处于空闲状态,未进行任何工作。

Simulated Llama3-70B Attention Layer with TK-GEMM

图 7. 使用 TK-GEMM 的模拟 Llama3-70B 注意力层

为了展示 CUDA 图的影响,我们在示例注意力层中为 TK-GEMM 内核创建了一个图并进行了重放。如下图所示,内核执行之间的间隔减少到了 6.65us。

Simulated Llama3-70B Attention Layer with TK-GEMM and CUDA Graphs

图 8. 使用 TK-GEMM 和 CUDA 图的模拟 Llama3-70B 注意力层

在实践中,这种优化将使 Llama3-70B 中单个注意力层的速度比在不使用 CUDA 图的模型中原生使用 TK-GEMM 快 6.4 倍

6.0 未来的潜在优化路径

TMA Hardware Unit

图 9. TMA 硬件单元
图片来源:Nvidia

Nvidia H100 具有一个 TMA 硬件单元。专用的 TMA 单元释放了寄存器和线程以执行其他工作,因为地址生成完全由 TMA 处理。对于内存受限的问题规模,当 Triton 启用对该特性的支持时,这将提供进一步的提升。

Tensor Core Utilization (Arrows Indicate Degrees of Freedom)

图 10. Tensor Core 利用率(箭头表示自由度)

为了确定我们对 Tensor Core 的利用程度,我们可以分析 Roofline 模型图。请注意,正如小 M 情况下预期的那样,我们处于内存受限区域。要改善内核延迟,我们可以通过利用数据局部性和其他循环优化来提高算术强度(对于固定问题规模而言),或者增加内存吞吐量。这需要针对 FP8 数据类型以及我们在 FP8 推理中预期的特定问题规模特征,采用更优的并行算法。

DRAM Throughput Circled, 1.65TB/s vs Peak 3.35TB/s on H100 (M=16, N=8192, K=8192)

图 11. DRAM 吞吐量(圈出部分),H100 峰值为 3.35TB/s,实测为 1.65TB/s(M=16, N=8192, K=8192)

最后,我们可以看到我们在 NVIDIA H100 上仅达到了峰值 DRAM 吞吐量的约 50%。高性能 GEMM 内核通常达到峰值吞吐量的 70-80%。这意味着仍有很大的提升空间,上述提到的技术(循环展开、优化并行化)对于获得额外增益是必要的。

7.0 未来工作

对于未来的研究,我们希望探索 CUTLASS 3.x 和 CuTe,以便对 Hopper 特性进行更直接的控制,特别是在获得直接的 TMA 控制和探索乒乓(pingpong)架构方面——这些技术在 FP8 GEMM 中已经显示出了有前景的结果。