博客

无需 CUDA 的 LLM 推理

在本篇博客中,我们探讨了如何利用主流大模型(如 Meta 的 Llama3-8BIBM 的 Granite-8B Code)实现 FP16 推理的方法,其中 100% 的计算均使用 OpenAI 的 Triton 语言完成。
对于使用我们基于 Triton 内核的模型进行的单 Token 生成,在 Nvidia H100 GPU 上,其性能达到了主流 CUDA 内核工作流的 0.76-0.78倍;在 Nvidia A100 GPU 上则达到了 0.62-0.82倍

为什么要探索 100% 使用 Triton?Triton 为大模型在不同类型的 GPU(NVIDIA、AMD,未来还包括 Intel 及其他基于 GPU 的加速器)上运行提供了一条路径。它在 Python 中为 GPU 编程提供了更高层级的抽象,使我们能够比使用特定供应商 API 更快地编写高性能内核。在本文后续部分,我们将分享如何实现“无 CUDA”计算、对各个内核进行微基准测试对比,并讨论如何进一步改进 Triton 内核以缩小性能差距。

图 1. 在 NVIDIA H100 和 A100 上,Llama3-8B 和 Granite-8B 的 Triton 与 CUDA 变体的推理吞吐量基准测试
设置:批次大小(batch size)= 2,输入序列长度 = 512,输出序列长度 = 256

2.0 Transformer 块的构成

我们首先对 Transformer 模型中的计算过程进行拆解。下图展示了典型 Transformer 块的“内核”。

 图 2. 按核心内核划分的 Transformer 块

Llama3 架构的核心操作列表如下:

  1. RMSNorm
  2. 矩阵乘法:融合(Fused)QKV
  3. RoPE
  4. 注意力
  5. 矩阵乘法:输出投影(Output Projection)
  6. RMSNorm
  7. 矩阵乘法:融合(Fused)门控 + 上投影(Up Projection)
  8. 激活函数:SiLU
  9. 逐元素乘法(Element Wise Multiplication)
  10. 矩阵乘法:下投影(Down Projection)

每一个操作都通过执行一个(或多个)内核在 GPU 上完成计算。虽然这些内核的具体实现可能因 Transformer 模型而异,但核心操作保持不变。例如,IBM 的 Granite 8B Code 模型在 MLP 层中使用偏置(bias),这与 Llama3 不同。这些差异确实需要对内核进行调整。一个典型的模型是由这些 Transformer 块通过嵌入层(embedding layers)连接而成的堆叠结构。

3.0 模型推理

典型的模型架构代码是通过 PyTorch 运行的 Python model.py 文件来共享的。在默认的 PyTorch 即时执行(eager execution)模式下,这些内核均使用 CUDA 执行。为了实现 Llama3-8B 和 Granite-8B 端到端的 100% Triton 推理,我们需要编写并集成手写的 Triton 内核,同时利用 torch.compile(以生成 Triton 操作)。首先,我们将较小的操作替换为编译器生成的 Triton 内核;其次,我们将更昂贵且复杂的计算(如矩阵乘法和 Flash Attention)替换为手写的 Triton 内核。

Torch.compile 会自动为 RMSNorm、RoPE、SiLU 和逐元素乘法生成 Triton 内核。使用像 Nsight Systems 这样的工具,我们可以观察到这些生成的内核;它们在矩阵乘法和注意力算子之间显示为细小的深绿色内核。

 图 3. 使用 torch.compile 的 Llama3-8B 跟踪,显示用于矩阵乘法和 Flash Attention 的 CUDA 内核

对于上述跟踪,我们注意到 Llama3-8B 类模型中占据端到端(E2E)延迟 80% 的两个主要操作是矩阵乘法和注意力内核,且两者仍为 CUDA 内核。因此,为了缩小剩余差距,我们将这两个内核都替换为手写的 Triton 内核。

4.0 Triton SplitK GEMM 内核

对于线性层中的矩阵乘法,我们编写了一个自定义的 FP16 Triton GEMM(通用矩阵乘法)内核,利用了 SplitK 工作分解。我们之前在其他博客中讨论过这种并行化方法,将其作为加速大模型推理解码部分的手段。

5.0 GEMM 内核调优

为了获得最佳性能,我们使用了穷举搜索方法来调优我们的 SplitK GEMM 内核。Granite-8B 和 Llama3-8B 的线性层具有以下形状:

线性层形状 (输入特征, 输出特征)
融合 QKV 投影(4096, 6144)
输出投影(4096, 4096)
融合门控 + 上投影(4096, 28672)
下投影(14336, 4096)

图 4. Granite-8B 和 Llama3-8B 线性层权重矩阵形状

这些线性层中的每一个都有不同的权重矩阵形状。因此,为了达到最佳性能,Triton 内核必须针对每种形状配置文件进行调优。在对每个线性层进行调优后,我们能够使 Llama3-8B 和 Granite-8B 的端到端速度比未调优的 Triton 内核提升 1.20倍

6.0 Flash Attention 内核

我们评估了一套现有的 Triton Flash Attention 内核,其配置如下:

  1. AMD Flash
  2. OpenAI Flash
  3. Dao AI Lab Flash
  4. XFormers Flash
  5. PyTorch FlexAttention

我们评估了这些内核的文本生成质量,首先是在即时(eager)模式下,然后是(如果我们能通过标准方法进行 torch.compile)编译模式下。对于内核 2-5,我们注意到以下情况:

内核文本生成质量Torch.compile支持任意序列长度
AMD Flash连贯
OpenAI Flash不连贯未评估。正在处理中,先调试即时模式下的精度
Dao AI Lab Flash不连贯未评估。正在处理中,先调试即时模式下的精度
Xformers FlashDecoding在评估文本质量前遇到了编译错误正在处理否(该内核针对解码进行了优化)
PyTorch FlexAttention连贯正在处理正在处理

图 5. 我们尝试的不同 Flash Attention 内核组合表

上表总结了我们“开箱即用”观察到的情况。经过一番努力,我们预计内核 2-5 可以被修改以满足上述标准。然而,这也表明,拥有一个能用于基准测试的内核,通常仅仅是使其成为端到端生产级内核的开始。
我们选择在后续测试中使用 AMD Flash Attention 内核,因为它可以通过 torch.compile 编译,并且在即时模式和编译模式下都能生成清晰的输出。

为了使 AMD Flash Attention 内核兼容 torch.compile,我们必须将其定义为 PyTorch 自定义算子。此过程详细解释见此处。该教程讨论了如何封装简单的图像裁剪操作,但我们注意到,封装更复杂的 Flash Attention 内核遵循类似的流程。两步法如下:

  1. 将该函数封装为 PyTorch 自定义算子
  1. 为该算子添加一个 FakeTensor 内核,根据 Flash 注意力输入张量(q、k 和 v)的形状,提供一种计算 Flash 内核输出形状的方法

在将 Triton Flash 内核定义为自定义算子后,我们成功地为端到端运行对其进行了编译。

图 6. 替换 Triton matmul 和 Triton Flash Attention 内核后,使用 torch.compile 的 Llama3-8B 跟踪

从图 5 中我们注意到,现在,在集成了 SplitK 矩阵乘法内核和经算子封装的 Flash Attention 内核并运行 torch.compile 后,我们能够实现一个 100% 使用 Triton 计算内核的前向传播。

7.0 端到端基准测试

我们在 NVIDIA H100 和 A100(单 GPU)上对 Granite-8B 和 Llama3-8B 模型进行了端到端测量。我们使用了两种不同的配置进行基准测试。

Triton 内核配置使用:

  1. Triton SplitK GEMM
  2. AMD Triton Flash Attention

CUDA 内核配置使用:

  1. cuBLAS GEMM
  2. cuDNN Flash Attention – 缩放点积注意力 (SDPA)

在典型的推理设置下,我们针对即时模式和编译模式发现了以下吞吐量和 Token 间延迟:

GPU模型内核配置中位延迟 (即时) [ms/tok]中位延迟 (编译) [ms/tok]
H100Granite-8BTriton27.4211.59
  CUDA18.849.50
 Llama3-8BTriton20.3610.61
  CUDA16.598.59
A100Granite-8BTriton53.4416.88
  CUDA37.1314.25
 Llama3-8BTriton44.4417.94
  CUDA32.4512.96

图 7. Granite-8B 和 Llama3-8B 在 H100 和 A100 上的单 Token 生成延迟
(批次大小 = 2,输入序列长度 = 512,输出序列长度 = 256)

总结来说,在 H100 上,Triton 模型最高可达到 CUDA 模型 78% 的性能,在 A100 上最高可达到 82%

性能差距可以通过我们在下一节讨论的 matmul 和 Flash Attention 内核延迟来解释。

8.0 微基准测试

内核Triton [us]CUDA [us]
QKV 投影 Matmul2521
Flash Attention138
输出投影 Matmul2117
门控 + 上投影 Matmul8483
下投影 Matmul5842

图 8. Triton 和 CUDA 内核延迟对比 (Llama3-8B 在 NVIDIA H100 上)
输入为随机 Prompt(bs=1,Prompt = 44 序列长度),解码延迟时间

基于以上数据,我们得出以下结论:

  1. Triton matmul 内核比 CUDA 慢 1.2-1.4 倍
  2. AMD 的 Triton Flash Attention 内核比 CUDA SDPA 慢 1.6 倍

这些结果凸显了进一步提高核心原语(如 GEMM 和 Flash Attention)性能的必要性。我们将此留作未来的研究,因为最近的作品(例如 FlashAttention-3, FlexAttention)提供了更好地利用底层硬件的方法,以及我们希望能在其基础上构建以实现更大加速的 Triton 路径。为了说明这一点,我们将 FlexAttention 与 SDPA 和 AMD 的 Triton Flash 内核进行了比较。

我们正在努力验证 FlexAttention 的端到端性能。目前,Flex 的初步微基准测试在较长上下文长度和解码问题形状(Query 向量较小时)中展现了潜力。

bar chart

图 9. NVIDIA H100 SXM5 80GB 上的 FlexAttention 内核基准测试
(批次=1, 注意力头数=32, 序列长度=seq_len, 头维度=128)

9.0 未来工作

未来,我们计划探索更好地利用硬件优化 matmul 的方法,例如我们发布的关于利用 H100 TMA 单元的博客,以及不同的工作分解(如 StreamK 等持久化内核技术),以使我们的 Triton 方法获得更大的加速。对于 Flash Attention,我们计划探索 FlexAttention 和 FlashAttention-3,因为这些内核中使用的技术可以被借鉴,以帮助进一步缩小 Triton 与 CUDA 之间的性能差距。
我们还注意到,我们之前的工作已展示了 FP8 Triton GEMM 内核相较于 cuBLAS FP8 GEMM 的出色性能,因此在未来的文章中,我们将探索端到端的 FP8 大模型推理。