作者:Adnan Hoque, Less Wright, Raghu Ganti 和 Mudhakar Srivatsa

在本篇博客中,我们将讨论我们如何使用 OpenAI 的 Triton 语言实现流行的 LLM 模型(例如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)的 FP16 推理,其中 100% 的计算都是使用 OpenAI 的 Triton 语言完成的。
对于使用基于 Triton 核的模型的单 token 生成时间,在 Nvidia H100 GPU 上,我们能够达到相对于以 CUDA 核为主的工作流程的 0.76-0.78 倍性能(针对 Llama 和 Granite 模型);在 Nvidia A100 GPU 上,则能达到 0.62-0.82 倍性能。

为何要探索使用 100% 的 Triton?Triton 为 LLM 在不同类型的 GPU(NVIDIA、AMD,未来还有 Intel 和其他基于 GPU 的加速器)上运行提供了途径。它还在 Python 中提供了更高层次的抽象来编程 GPU,使我们能够比使用特定供应商 API 更快地编写高性能核。在本博客的其余部分,我们将分享如何实现无需 CUDA 的计算,对单个核进行微基准测试以进行比较,并讨论如何进一步改进未来的 Triton 核以弥合差距。

图 1:Llama3-8B 和 Granite-8B 的 Triton 和 CUDA 变体在 NVIDIA H100 和 A100 上的推理吞吐量基准测试
设置:批量大小 = 2,输入序列长度 = 512,输出序列长度 = 256

2.0 Transformer 模块的构成

我们首先分析基于 Transformer 的模型中发生的计算。下图显示了典型的 Transformer 模块的“核”。

图 2:Transformer 模块的核心核

Llama3 架构的核心操作总结如下列表

  1. RMSNorm
  2. 矩阵乘法:融合 QKV
  3. RoPE
  4. 注意力
  5. 矩阵乘法:输出投影
  6. RMSNorm
  7. 矩阵乘法:融合 Gate + Up 投影
  8. 激活函数:SiLU
  9. 逐元素乘法
  10. 矩阵乘法:Down 投影

这些操作中的每一个都在 GPU 上通过执行一个(或多个)核来计算。虽然这些核的具体实现因不同的 Transformer 模型而异,但核心操作保持不变。例如,IBM 的 Granite 8B Code 模型在 MLP 层使用了偏置,这与 Llama3 不同。这些变化确实需要修改核。典型的模型是由这些 Transformer 模块与嵌入层堆叠并连接在一起组成的。

3.0 模型推理

典型的模型架构代码通常位于由 PyTorch 启动的 python model.py 文件中。在默认的 PyTorch eager execution 模式下,这些核都通过 CUDA 执行。为了实现 Llama3-8B 和 Granite-8B 端到端推理的 100% Triton 计算,我们需要编写和集成手写的 Triton 核,并利用 torch.compile(生成 Triton 操作)。首先,我们用编译器生成的 Triton 核替换较小的操作;其次,我们用手写的 Triton 核替换更耗时和复杂的计算(例如矩阵乘法和 Flash Attention)。

Torch.compile 会自动为 RMSNorm、RoPE、SiLU 和逐元素乘法生成 Triton 核。使用 Nsight Systems 等工具,我们可以观察到这些生成的核;它们在矩阵乘法和注意力计算之间显示为微小的深绿色核。

图 3:使用 torch.compile 的 Llama3-8B 追踪,显示矩阵乘法和 Flash Attention 使用的是 CUDA 核

对于上述追踪,我们注意到在 Llama3-8B 风格的模型中,占 E2E 延迟 80% 的两个主要操作是矩阵乘法和注意力核,两者都仍然是 CUDA 核。因此,为了弥合剩余的差距,我们用手写的 Triton 核替换了 matmul 和注意力核。

4.0 Triton SplitK GEMM 核

对于线性层中的矩阵乘法,我们编写了一个自定义的 FP16 Triton GEMM(通用矩阵乘法)核,它利用了 SplitK 工作分解。我们在其他博客中曾讨论过这种并行化方法,作为加速 LLM 推理的解码部分的手段。

5.0 GEMM 核调优

为了获得最佳性能,我们使用了穷举搜索方法来调优我们的 SplitK GEMM 核。Granite-8B 和 Llama3-8B 的线性层具有以下形状:

线性层 形状 (in_features, out_features)
融合 QKV 投影 (4096, 6144)
输出投影 (4096, 4096)
融合 Gate + Up 投影 (4096, 28672)
Down 投影 (14336, 4096)

图 4:Granite-8B 和 Llama3-8B 线性层权重矩阵形状

这些线性层中的每一个都具有不同的权重矩阵形状。因此,为了获得最佳性能,必须针对每种形状配置文件对 Triton 核进行调优。在对每个线性层进行调优后,我们能够在 Llama3-8B 和 Granite-8B 上相对于未经调优的 Triton 核实现 1.20 倍的 E2E 加速。

6.0 Flash Attention 核

我们评估了一系列现有 Triton Flash Attention 核,它们具有不同的配置,包括:

  1. AMD Flash
  2. OpenAI Flash
  3. Dao AI Lab Flash
  4. XFormers Flash
  5. PyTorch FlexAttention

我们评估了这些核在 eager 模式下的文本生成质量,随后(如果能够使用标准方法通过 torch.compile 编译核)在编译模式下也进行了评估。对于核 2-5,我们注意到以下情况:

文本生成质量 Torch.compile 支持任意序列长度
AMD Flash 连贯
OpenAI Flash 不连贯 未评估。正在进行中(WIP),先调试 eager 模式下的精度问题。
Dao AI Lab Flash 不连贯 未评估。正在进行中(WIP),先调试 eager 模式下的精度问题。
Xformers FlashDecoding 在我们评估文本质量之前遇到了编译错误 WIP 否(此核针对解码进行了优化)
PyTorch FlexAttention 连贯 WIP WIP

图 5:我们尝试的不同 Flash Attention 核组合表

上表总结了我们开箱即用的观察结果。经过一些努力,我们期望可以修改核 2-5 以满足上述标准。然而,这也表明,拥有一个适用于基准测试的核通常仅仅是使其可用作端到端生产核的开始。
在后续测试中,我们选择使用 AMD Flash Attention 核,因为它可以通过 torch.compile 编译,并在 eager 模式和编译模式下都能生成可读的输出。

为了使 AMD Flash Attention 核与 torch.compile 兼容,我们必须将其定义为 torch 自定义操作。这个过程在此处详细解释。该教程链接讨论了如何封装一个简单的图像裁剪操作。然而,我们注意到封装一个更复杂的 Flash Attention 核遵循类似的过程。两步法如下:

  1. 将函数封装为 PyTorch 自定义操作

  1. 向该操作添加一个 FakeTensor 核,它根据 flash 输入张量(q、k 和 v)的形状,提供了一种计算 flash 核输出形状的方法

将 Triton Flash 核定义为自定义操作后,我们成功地将其编译用于 E2E 运行。

图 6:在替换为 Triton matmul 和 Triton Flash Attention 核后,使用 torch.compile 的 Llama3-8B 追踪

从图 5 中,我们注意到现在,在集成了 SplitK 矩阵乘法核、经过 torch op 封装的 Flash Attention 核,并运行 torch.compile 后,我们能够实现使用 100% Triton 计算核的前向传播。

7.0 端到端基准测试

我们在 NVIDIA H100 和 A100(单 GPU)上使用 Granite-8B 和 Llama3-8B 模型进行了端到端测量。我们使用两种不同的配置进行了基准测试。

Triton 核配置使用

  1. Triton SplitK GEMM
  2. AMD Triton Flash Attention

CUDA 核配置使用

  1. cuBLAS GEMM
  2. cuDNN Flash Attention - 缩放点积注意力 (SDPA)

在典型的推理设置下,我们在 eager 和 torch 编译模式下均获得了以下吞吐量和 token 间延迟:

GPU 模型 核配置 中位数延迟 (Eager) [ms/tok] 中位数延迟 (编译) [ms/tok]
H100 Granite-8B Triton 27.42 11.59
    CUDA 18.84 9.50
  Llama3-8B Triton 20.36 10.61
    CUDA 16.59 8.59
A100 Granite-8B Triton 53.44 16.88
    CUDA 37.13 14.25
  Llama3-8B Triton 44.44 17.94
    CUDA 32.45 12.96

图 7:Granite-8B 和 Llama3-8B 在 H100 和 A100 上的单 Token 生成延迟
(批量大小 = 2,输入序列长度 = 512,输出序列长度 = 256)

总而言之,Triton 模型在 H100 上可以达到 CUDA 模型性能的 78%,在 A100 上则可以达到 82%

性能差距可以通过我们在矩阵乘法和 Flash Attention 中观察到的核延迟来解释,这将在下一节讨论。

8.0 微基准测试

Triton [us] CUDA [us]
QKV 投影矩阵乘法 25 21
Flash Attention 13 8
输出投影矩阵乘法 21 17
Gate + Up 投影矩阵乘法 84 83
Down 投影矩阵乘法 58 42

图 8:Triton 与 CUDA 核延迟比较(Llama3-8B 在 NVIDIA H100 上)
输入是任意 prompt(批量大小=1,prompt = 44 序列长度),解码延迟时间

从上表(图 8)中,我们注意到以下几点:

  1. Triton 矩阵乘法核比 CUDA 慢 1.2-1.4 倍

  2. AMD 的 Triton Flash Attention 核比 CUDA SDPA 慢 1.6 倍

这些结果突显了进一步提高 GEMM 和 Flash Attention 等核心原语核性能的必要性。我们将此留待未来的研究,因为最近的工作(例如 FlashAttention-3FlexAttention)提供了更好地利用底层硬件的方法,以及我们希望能够在此基础上构建 Triton 路径以实现更大的加速。为了说明这一点,我们将 FlexAttention 与 SDPA 和 AMD 的 Triton Flash 核进行了比较。

我们正在验证 FlexAttention 的 E2E 性能。目前,Flex 的初步微基准测试结果表明,对于更长的上下文长度和解码问题形状(其中查询向量较小)显示出潜力。

图 9:FlexAttention 核在 NVIDIA H100 SXM5 80GB 上的基准测试
(批量=1,头数=32,序列长度=序列长度,头维度=128)

9.0 未来工作

未来的工作方面,我们计划探索进一步优化矩阵乘法的方法,以更好地利用硬件,例如我们发布的关于利用 TMA 实现 H100 加速的这篇博客,以及不同的工作分解(例如 StreamK 等持久化核技术),从而为我们基于 Triton 的方法带来更大的加速。对于 Flash Attention,我们计划探索 FlexAttention 和 FlashAttention-3,因为这些核中使用的技术可以帮助进一步弥合 Triton 和 CUDA 之间的差距。
我们还注意到,我们之前的工作表明,FP8 Triton GEMM 核在性能上相对于 cuBLAS FP8 GEMM 取得了可喜的成果,因此在未来的文章中,我们将探索端到端 FP8 LLM 推理。