作者:Adnan Hoque, Less Wright, Chih Chieh Yang

1.0 摘要

我们提出了一个优化的 Triton FP8 GEMM(通用矩阵乘法)内核 TK-GEMM,它利用了 SplitK 并行化。对于小批量推理,TK-GEMM 在 NVIDIA H100 GPU 上针对 Llama3-70B 推理问题规模提供了相对于基础 Triton matmul 实现高达 1.94 倍 的速度提升,相对于 cuBLAS FP8 提升 1.87 倍,相对于 cuBLAS FP16 提升 1.71 倍

TK-GEMM Speedup over PyTorch (calling cuBLAS) for Llama3-70B Attention Layer Matrix Shapes (N=K=8192)

图 1. TK-GEMM 相对于 PyTorch(调用 cuBLAS)在 Llama3-70B 注意力层矩阵形状(N=K=8192)上的加速比

在这篇博客中,我们将介绍如何使用 Triton 设计一个用于 FP8 推理的优化内核,并针对 Lama3-70B 推理进行了调优。我们将介绍 FP8(8 位浮点),这是 Hopper 代 GPU (SM90) 支持的新数据类型,Triton 支持的关键 SM90 特性,以及我们如何修改并行化以最大化内存密集型(推理)问题规模的内存吞吐量。

我们还专门用了一个章节介绍 CUDA Graphs,这是一项重要的技术,它将有助于实现内核级别的加速,并使希望在生产环境中使用 Triton 内核的开发者获得额外的性能提升。

仓库和代码可在以下地址获取:https://github.com/pytorch-labs/applied-ai

2.0 FP8 数据类型

FP8 数据类型由 Nvidia、Arm 和 Intel 联合引入,是 16 位浮点类型的后续版本。其位宽减半,在 Transformer 网络中相对于前代有潜力提供显著的吞吐量提升。FP8 数据类型包含 2 种格式

E4M3(4 位指数和 3 位尾数)。能够存储 +/- 448 和 nan。
E5M2(5 位指数和 2 位尾数)。能够存储 +/- 57,334、nan 和 inf。

BF16, FP16, FP8 E4M3 and FP8 E5M2

上图: BF16、FP16、FP8 E4M3 和 FP8 E5M2。
为了展示精度差异,每种格式中与 0.3952 最接近的表示如下图所示。
图片来源:Nvidia

我们在推理和训练前向传播中使用 E4M3,因为它具有更高的精度;在训练后向传播中使用 E5M2,因为它具有更高的动态范围。Nvidia 已将其 H100 FP8 Tensor Core 设计为提供 3958 TFLOPS 的峰值性能,是 FP16 Tensor Core FLOPS 的 2 倍

我们在设计 Triton 内核时考虑了这些硬件创新,在本博客的其余部分,我们将讨论如何利用这些特性并验证 Triton 编译器是否确实利用了这些特性。

3.0 Triton Hopper 支持和 FP8 Tensor Core 指令

Hopper GPU 架构增加了以下 新特性,我们预计这些特性将加速 FP8 GEMM。

  • TMA (Tensor Memory Accelerator) 硬件单元
  • WGMMA (Warp Group Matrix Multiply-Accumulate Instruction)
  • 线程块集群

Triton 目前利用了其中一个特性,即 wgmma 指令,而 PyTorch(调用 cuBLAS)则利用了全部 3 个特性,这使得这些加速更加令人印象深刻。为了充分利用 Hopper FP8 Tensor Core,即使仍然支持旧的 mma.sync 指令,wgmma 也是必需的。

mma 和 wgmma 指令之间的关键区别在于,不是由 1 个 CUDA warp 负责一个输出分片,而是由整个 warp 组(4 个 CUDA warp)异步地为一个输出分片做出贡献。

为了实际了解这个指令的样子,并验证我们的 Triton 内核确实利用了这个特性,我们使用 nsight compute 分析了 PTX 和 SASS 汇编代码。

PTX Assembly

图 2. PTX 汇编代码

这个指令在 SASS 中进一步被降级为 QGMMA 指令。

SASS Assembly

图 3. SASS 汇编代码

这两个指令都告诉我们正在将两个 FP8 E4M3 输入张量相乘并在 F32 中进行累加,这证实了 TK-GEMM 内核正在利用 FP8 Tensor Core,并且降级过程是正确的。

4.0 SplitK 工作分解

TK-GEMM vs Base Triton GEMM TFLOPS for M = 1-64

图 4. M = 1-64 时 TK-GEMM 与基础 Triton GEMM 的 TFLOPS 对比

基础 Triton FP8 GEMM 实现对于小 M 范围的性能不佳,其中矩阵乘法 A (MxN) x B (NxK) 的 M < N, K。为了针对这种矩阵特性进行优化,我们应用了 SplitK 工作分解,而不是基础 Triton 内核中的数据并行分解。这大大改善了小 M 范围的延迟。

背景介绍:SplitK 沿 k 维度启动额外的线程块来计算部分输出和。然后使用原子归约将每个线程块的部分结果求和。这允许更细粒度的工作分解,从而带来性能提升。有关 SplitK 的更多详细信息,请参阅我们的 arxiv 论文

在仔细调整了内核的其他相关超参数,如瓦块大小、warp 数量和流水线阶段数量以匹配 Llama3-70B 问题规模后,我们成功地比 Triton 基础实现 提升了 **1.94 倍** 的速度。有关超参数调优的更全面介绍,请参阅我们的博客

NCU profiler times for TK-GEMM under varying batch sizes, and compared with PyTorch (calling cuBLAS) FP8 and FP16.

上图TK-GEMM 在不同批量大小下的 NCU 性能分析器时间,并与 PyTorch(调用 cuBLAS)FP8 和 FP16 进行比较。

注意,从 M=32 开始,cuBLAS FP8 内核开始超越 TK-GEMM。对于 M >= 32,我们怀疑我们找到的超参数不是最优的,因此需要进行另一组实验来确定中等大小 M 范围的最佳参数。

5.0 使用 CUDA Graphs 实现端到端加速

为了在端到端场景中实现这些加速,我们必须同时考虑内核执行时间(GPU 时长)以及总耗时(CPU+GPU 时长)。Triton 内核(手写,而非 torch compile 生成)已知存在内核启动延迟高的问题。如果我们使用 torch profiler 跟踪 TK-GEMM 内核,我们可以看到 CPU 端的调用栈,从而精确定位导致速度变慢的原因。

CPU Launch Overhead: 2.413ms

图 5. CPU 启动开销:2.413毫秒

从上图可以看出,我们优化内核的大部分总耗时被 JIT(即时)编译开销所占据。为了解决这个问题,我们可以使用 CUDA Graphs。

CUDA Graphs Visualization

图 6. CUDA Graphs 可视化
图片来源:PyTorch

关键思想是,与其进行多次内核启动,不如创建一个图并进行实例化(一次性成本),然后提交该图的实例进行执行。为了说明这一点,我们模拟了 Llama3-70B 注意力层。如下图所示(使用 nsight systems 生成),由于 CPU 内核启动开销,每个 GEMM 之间的时间间隔为 165微秒,而实际 matmul 花费的时间为 **12微秒**。这意味着在注意力层中,**92%** 的时间 GPU 是空闲的,没有进行任何工作。

Simulated Llama3-70B Attention Layer with TK-GEMM

图 7. 使用 TK-GEMM 模拟 Llama3-70B 注意力层

为了展示 CUDA Graphs 的影响,我们然后创建了玩具注意力层中 TK-GEMM 内核的图,并重放了该图。如下图所示,内核执行之间的间隔减少到 6.65微秒。

Simulated Llama3-70B Attention Layer with TK-GEMM and CUDA Graphs

图 8. 使用 TK-GEMM 和 CUDA Graphs 模拟 Llama3-70B 注意力层

实际上,与在没有 CUDA Graphs 的模型中天真地使用 TK-GEMM 相比,这项优化将使 Llama3-70B 中的单个注意力层提速 6.4 倍

6.0 潜在的未来优化方向

TMA Hardware Unit

图 9. TMA 硬件单元
图片来源:Nvidia

Nvidia H100 配备了 TMA 硬件单元。专用的 TMA 单元释放了寄存器和线程用于执行其他工作,因为地址生成完全由 TMA 处理。对于内存密集型问题规模,当 Triton 启用对该特性的支持时,可以提供进一步的增益。

Tensor Core Utilization (Arrows Indicate Degrees of Freedom)

图 10. Tensor Core 利用率(箭头表示自由度)

为了确定我们对 Tensor Core 的利用率如何,我们可以分析屋顶线图。注意,正如小 M 预期那样,我们处于内存限制区域。为了提高内核延迟,我们可以增加算术强度,这在固定问题规模下只能通过利用数据局部性和其他循环优化来实现,或者增加内存吞吐量。这需要一个针对 FP8 数据类型以及我们期望在 FP8 推理中看到的这类问题规模特性进行优化的并行算法。

DRAM Throughput Circled, 1.65TB/s vs Peak 3.35TB/s on H100 (M=16, N=8192, K=8192)

图 11. 圈出的 DRAM 吞吐量,H100 上为 1.65TB/s vs 峰值 3.35TB/s (M=16, N=8192, K=8192)

最后,我们可以看到在 NVIDIA H100 上,我们仅达到峰值 DRAM 吞吐量的约 50%。高性能 GEMM 内核通常能达到峰值吞吐量的约 70-80%。这意味着仍有很大的改进空间,并且需要上述技术(循环展开、优化并行化)来获得额外的增益。

7.0 未来工作

对于未来的研究,我们希望探索 CUTLASS 3.x 和 CuTe,以便更直接地控制 Hopper 特性,特别是在直接获取 TMA 控制和探索 pingpong 架构方面,后者在 FP8 GEMM 方面已显示出有希望的结果。