跳转到主要内容
博客

使用 Triton Kernels 加速 Llama3 FP8 推理

作者: 2024 年 5 月 1 日2024 年 11 月 13 日暂无评论

1.0 概述

我们提出了一种优化的 Triton FP8 GEMM(通用矩阵乘法)内核 TK-GEMM,它利用了 SplitK 并行化。对于小批量推理,在 NVIDIA H100 GPU 上,针对 Llama3-70B 推理问题规模,TK-GEMM 比基础 Triton 矩阵乘法实现提速高达 1.94 倍,比 cuBLAS FP8 提速 1.87 倍,比 cuBLAS FP16 提速 1.71 倍

TK-GEMM Speedup over PyTorch (calling cuBLAS) for Llama3-70B Attention Layer Matrix Shapes (N=K=8192)

图 1. TK-GEMM 在 Llama3-70B 注意力层矩阵形状(N=K=8192)下相对于 PyTorch(调用 cuBLAS)的加速效果

在这篇博客中,我们将介绍如何使用 Triton 设计优化的 FP8 推理内核,并针对 Lama3-70B 推理进行调整。我们将介绍 FP8(8 位浮点),一种 Hopper 代 GPU (SM90) 支持的新数据类型,Triton 支持的关键 SM90 功能,以及我们如何修改并行化以最大化内存绑定(推理)问题规模的内存吞吐量。

我们还专门讨论了 CUDA 图,这是一项重要的技术,有助于实现内核级别的加速,并使希望在生产环境中使用 Triton 内核的开发人员获得额外的性能提升。

代码仓库和代码可在以下地址获取:https://github.com/pytorch-labs/applied-ai

2.0 FP8 数据类型

FP8 数据类型由 Nvidia、Arm 和 Intel 联合 引入,作为 16 位浮点类型的后继。其位数为一半,有可能比其前代产品为 Transformer 网络提供显著的吞吐量改进。FP8 数据类型包含 2 种格式:

E4M3 (4 位指数和 3 位尾数)。能够存储 +/- 448 和 nan。
E5M2 (5 位指数和 2 位尾数)。能够存储 +/- 57,334、nan 和 inf。

BF16, FP16, FP8 E4M3 and FP8 E5M2

上图: BF16、FP16、FP8 E4M3 和 FP8 E5M2。
为了显示精度差异,每种格式中显示了最接近 0.3952 的表示。
图片来源:Nvidia

我们在推理和前向传播训练中使用 E4M3,因为它具有更高的精度,而在训练反向传播中使用 E5M2,因为它具有更高的动态范围。Nvidia 设计的 H100 FP8 Tensor Core 可提供 3958 TFLOPS 的峰值,是 FP16 Tensor Core FLOPS 的 2 倍

我们设计 Triton 内核时考虑了这些硬件创新,在本博客的其余部分,我们将讨论利用和验证 Triton 编译器确实使用了这些功能的方法。

3.0 Triton Hopper 支持和 FP8 Tensor Core 指令

Hopper GPU 架构增加了以下 新功能,我们预计这些功能将加速 FP8 GEMM。

  • TMA (Tensor Memory Accelerator) 硬件单元
  • WGMMA (Warp Group Matrix Multiply-Accumulate Instruction)
  • 线程块集群

Triton 目前利用了其中一项功能,即 wgmma 指令,而 PyTorch(调用 cuBLAS)则利用了所有 3 项功能,这使得这些加速更加令人印象深刻。要充分利用 Hopper FP8 Tensor Core,wgmma 是必需的,即使仍然支持较旧的 mma.sync 指令。

mma 和 wgmma 指令之间的关键区别在于,wgmma 指令不是由 1 个 CUDA warp 负责一个输出分片,而是由整个 warp 组(4 个 CUDA warp)异步地为一个输出分片做出贡献。

为了实际了解此指令的外观,并验证我们的 Triton 内核确实利用了此功能,我们使用 nsight compute 分析了 PTX 和 SASS 汇编。

PTX Assembly

图 2. PTX 汇编

此指令在 SASS 中进一步降级为 QGMMA 指令。

SASS Assembly

图 3. SASS 汇编

这两条指令都告诉我们,我们正在将两个 FP8 E4M3 输入张量相乘并在 F32 中累加,这证实了 TK-GEMM 内核正在利用 FP8 张量核并且降级正在正确完成。

4.0 SplitK 工作分解

TK-GEMM vs Base Triton GEMM TFLOPS for M = 1-64

图 4. TK-GEMM 与基础 Triton GEMM TFLOPS,M = 1-64

基础 Triton FP8 GEMM 实现对于小 M 范围的性能 不佳,其中对于 A (MxN) x B (NxK) 的矩阵乘法,M < N, K。为了优化这种类型的矩阵配置文件,我们采用了 SplitK 工作分解,而不是基础 Triton 内核中的数据并行分解。这大大改善了小 M 范围的延迟。

背景知识:SplitK 沿着 k 维度启动额外的线程块来计算部分输出和。每个线程块的部分结果随后使用原子规约求和。这允许更细粒度的工作分解,从而提高性能。有关 SplitK 的更多详细信息,请参见我们的 arxiv 论文

在针对 Llama3-70B 问题规模仔细调整了内核的其他相关超参数,例如瓦片大小、warp 数量和流水线阶段数之后,我们能够比 Triton 基础实现 提速高达 1.94 倍。有关超参数调整的更全面的介绍,请参见我们的 博客

NCU profiler times for TK-GEMM under varying batch sizes, and compared with PyTorch (calling cuBLAS) FP8 and FP16.

上图TK-GEMM 在不同批处理大小下的 NCU 分析器时间,并与 PyTorch (调用 cuBLAS) FP8 和 FP16 进行比较。

请注意,从 M=32 开始,cuBLAS FP8 内核开始优于 TK-GEMM。对于 M >= 32,我们怀疑我们找到的超参数不是最优的,因此需要进行另一组实验来确定中等 M 范围的最优参数。

5.0 CUDA 图实现端到端加速

为了在端到端设置中实现这些加速,我们必须同时考虑内核执行时间(GPU 持续时间)和墙钟时间(CPU+GPU)持续时间。Triton 内核(手写,而不是 torch 编译生成)已知会遇到高内核启动延迟。如果我们使用 torch profiler 跟踪 TK-GEMM 内核,我们可以在 CPU 侧查看调用堆栈,以精确找出导致减速的原因。

CPU Launch Overhead: 2.413ms

图 5. CPU 启动开销:2.413 毫秒

从上图可以看出,我们优化内核的大部分墙钟时间都由 JIT(即时)编译开销主导。为了解决这个问题,我们可以使用 CUDA 图。

CUDA Graphs Visualization

图 6. CUDA 图可视化
图片来源:PyTorch

核心思想是,我们不再进行多次内核启动,而是可以创建并实例化一个图(一次性成本),然后提交该图的实例进行执行。为了说明这一点,我们模拟了一个 Llama3-70B 注意力层。如下图所示,使用 nsight systems 生成,每个 GEMM 之间的时间为 165us,而实际矩阵乘法花费的时间为 12us,这是由于 CPU 内核启动开销造成的。这意味着在一个注意力层中,GPU 92% 的时间处于空闲状态,没有做任何工作。

Simulated Llama3-70B Attention Layer with TK-GEMM

图 7. 模拟 Llama3-70B 注意力层与 TK-GEMM

为了展示 CUDA 图的影响,我们接着在玩具注意力层中创建了 TK-GEMM 内核的图并重放了该图。下面我们可以看到内核执行之间的间隔减少到 6.65us。

Simulated Llama3-70B Attention Layer with TK-GEMM and CUDA Graphs

图 8. 模拟 Llama3-70B 注意力层与 TK-GEMM 和 CUDA 图

在实践中,这种优化将使 Llama3-70B 中单个注意力层提速 6.4 倍,而无需在没有 CUDA 图的模型中天真地使用 TK-GEMM。

6.0 潜在的未来优化路径

TMA Hardware Unit

图 9. TMA 硬件单元
图片来源:Nvidia

Nvidia H100 具有 TMA 硬件单元。专用的 TMA 单元可释放寄存器和线程以执行其他工作,因为地址生成完全由 TMA 处理。对于内存受限的问题规模,当 Triton 支持此功能时,这可以提供进一步的收益。

Tensor Core Utilization (Arrows Indicate Degrees of Freedom)

图 10. Tensor Core 利用率(箭头表示自由度)

为了确定我们对 Tensor Core 的利用程度,我们可以分析屋脊线图。请注意,正如预期的那样,对于小 M,我们处于内存受限区域。为了改善内核延迟,我们可以通过利用数据局部性和其他循环 优化 来增加算术强度,或者增加内存吞吐量。这需要一个更优化的并行算法,专门针对 FP8 数据类型以及我们期望在 FP8 推理中看到的问题大小特性。

DRAM Throughput Circled, 1.65TB/s vs Peak 3.35TB/s on H100 (M=16, N=8192, K=8192)

图 11. DRAM 吞吐量圈出,H100 上为 1.65TB/s,峰值为 3.35TB/s (M=16, N=8192, K=8192)

最后,我们可以看到 NVIDIA H100 上的 DRAM 吞吐量仅达到峰值的 50% 左右。高性能 GEMM 内核通常能达到峰值吞吐量的 70-80%。这意味着仍有很大的改进空间,并且需要上述技术(循环展开、优化并行化)才能获得额外收益。

7.0 未来工作

对于未来的研究,我们希望探索 CUTLASS 3.x 和 CuTe,以实现对 Hopper 功能更直接的控制,特别是在获得直接 TMA 控制和探索乒乓架构方面,这在 FP8 GEMM 中已显示出有希望的结果。