在本篇博客中,我们探讨了如何利用主流大模型(如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)实现 FP16 推理的方法,其中 100% 的计算均使用 OpenAI 的 Triton 语言完成。
对于使用我们基于 Triton 内核的模型进行的单 Token 生成,在 Nvidia H100 GPU 上,其性能达到了主流 CUDA 内核工作流的 0.76-0.78倍;在 Nvidia A100 GPU 上则达到了 0.62-0.82倍。
为什么要探索 100% 使用 Triton?Triton 为大模型在不同类型的 GPU(NVIDIA、AMD,未来还包括 Intel 及其他基于 GPU 的加速器)上运行提供了一条路径。它在 Python 中为 GPU 编程提供了更高层级的抽象,使我们能够比使用特定供应商 API 更快地编写高性能内核。在本文后续部分,我们将分享如何实现“无 CUDA”计算、对各个内核进行微基准测试对比,并讨论如何进一步改进 Triton 内核以缩小性能差距。

图 1. 在 NVIDIA H100 和 A100 上,Llama3-8B 和 Granite-8B 的 Triton 与 CUDA 变体的推理吞吐量基准测试
设置:批次大小(batch size)= 2,输入序列长度 = 512,输出序列长度 = 256
2.0 Transformer 块的构成
我们首先对 Transformer 模型中的计算过程进行拆解。下图展示了典型 Transformer 块的“内核”。
图 2. 按核心内核划分的 Transformer 块
Llama3 架构的核心操作列表如下:
- RMSNorm
- 矩阵乘法:融合(Fused)QKV
- RoPE
- 注意力
- 矩阵乘法:输出投影(Output Projection)
- RMSNorm
- 矩阵乘法:融合(Fused)门控 + 上投影(Up Projection)
- 激活函数:SiLU
- 逐元素乘法(Element Wise Multiplication)
- 矩阵乘法:下投影(Down Projection)
每一个操作都通过执行一个(或多个)内核在 GPU 上完成计算。虽然这些内核的具体实现可能因 Transformer 模型而异,但核心操作保持不变。例如,IBM 的 Granite 8B Code 模型在 MLP 层中使用偏置(bias),这与 Llama3 不同。这些差异确实需要对内核进行调整。一个典型的模型是由这些 Transformer 块通过嵌入层(embedding layers)连接而成的堆叠结构。
3.0 模型推理
典型的模型架构代码是通过 PyTorch 运行的 Python model.py 文件来共享的。在默认的 PyTorch 即时执行(eager execution)模式下,这些内核均使用 CUDA 执行。为了实现 Llama3-8B 和 Granite-8B 端到端的 100% Triton 推理,我们需要编写并集成手写的 Triton 内核,同时利用 torch.compile(以生成 Triton 操作)。首先,我们将较小的操作替换为编译器生成的 Triton 内核;其次,我们将更昂贵且复杂的计算(如矩阵乘法和 Flash Attention)替换为手写的 Triton 内核。
Torch.compile 会自动为 RMSNorm、RoPE、SiLU 和逐元素乘法生成 Triton 内核。使用像 Nsight Systems 这样的工具,我们可以观察到这些生成的内核;它们在矩阵乘法和注意力算子之间显示为细小的深绿色内核。
图 3. 使用 torch.compile 的 Llama3-8B 跟踪,显示用于矩阵乘法和 Flash Attention 的 CUDA 内核
对于上述跟踪,我们注意到 Llama3-8B 类模型中占据端到端(E2E)延迟 80% 的两个主要操作是矩阵乘法和注意力内核,且两者仍为 CUDA 内核。因此,为了缩小剩余差距,我们将这两个内核都替换为手写的 Triton 内核。
4.0 Triton SplitK GEMM 内核
对于线性层中的矩阵乘法,我们编写了一个自定义的 FP16 Triton GEMM(通用矩阵乘法)内核,利用了 SplitK 工作分解。我们之前在其他博客中讨论过这种并行化方法,将其作为加速大模型推理解码部分的手段。
5.0 GEMM 内核调优
为了获得最佳性能,我们使用了穷举搜索方法来调优我们的 SplitK GEMM 内核。Granite-8B 和 Llama3-8B 的线性层具有以下形状:
| 线性层 | 形状 (输入特征, 输出特征) |
|---|---|
| 融合 QKV 投影 | (4096, 6144) |
| 输出投影 | (4096, 4096) |
| 融合门控 + 上投影 | (4096, 28672) |
| 下投影 | (14336, 4096) |
图 4. Granite-8B 和 Llama3-8B 线性层权重矩阵形状
这些线性层中的每一个都有不同的权重矩阵形状。因此,为了达到最佳性能,Triton 内核必须针对每种形状配置文件进行调优。在对每个线性层进行调优后,我们能够使 Llama3-8B 和 Granite-8B 的端到端速度比未调优的 Triton 内核提升 1.20倍。
6.0 Flash Attention 内核
我们评估了一套现有的 Triton Flash Attention 内核,其配置如下:
我们评估了这些内核的文本生成质量,首先是在即时(eager)模式下,然后是(如果我们能通过标准方法进行 torch.compile)编译模式下。对于内核 2-5,我们注意到以下情况:
| 内核 | 文本生成质量 | Torch.compile | 支持任意序列长度 |
|---|---|---|---|
| AMD Flash | 连贯 | 是 | 是 |
| OpenAI Flash | 不连贯 | 未评估。正在处理中,先调试即时模式下的精度 | 否 |
| Dao AI Lab Flash | 不连贯 | 未评估。正在处理中,先调试即时模式下的精度 | 是 |
| Xformers FlashDecoding | 在评估文本质量前遇到了编译错误 | 正在处理 | 否(该内核针对解码进行了优化) |
| PyTorch FlexAttention | 连贯 | 正在处理 | 正在处理 |
图 5. 我们尝试的不同 Flash Attention 内核组合表
上表总结了我们“开箱即用”观察到的情况。经过一番努力,我们预计内核 2-5 可以被修改以满足上述标准。然而,这也表明,拥有一个能用于基准测试的内核,通常仅仅是使其成为端到端生产级内核的开始。
我们选择在后续测试中使用 AMD Flash Attention 内核,因为它可以通过 torch.compile 编译,并且在即时模式和编译模式下都能生成清晰的输出。
为了使 AMD Flash Attention 内核兼容 torch.compile,我们必须将其定义为 PyTorch 自定义算子。此过程详细解释见此处。该教程讨论了如何封装简单的图像裁剪操作,但我们注意到,封装更复杂的 Flash Attention 内核遵循类似的流程。两步法如下:
- 将该函数封装为 PyTorch 自定义算子

- 为该算子添加一个 FakeTensor 内核,根据 Flash 注意力输入张量(q、k 和 v)的形状,提供一种计算 Flash 内核输出形状的方法

在将 Triton Flash 内核定义为自定义算子后,我们成功地为端到端运行对其进行了编译。

图 6. 替换 Triton matmul 和 Triton Flash Attention 内核后,使用 torch.compile 的 Llama3-8B 跟踪
从图 5 中我们注意到,现在,在集成了 SplitK 矩阵乘法内核和经算子封装的 Flash Attention 内核并运行 torch.compile 后,我们能够实现一个 100% 使用 Triton 计算内核的前向传播。
7.0 端到端基准测试
我们在 NVIDIA H100 和 A100(单 GPU)上对 Granite-8B 和 Llama3-8B 模型进行了端到端测量。我们使用了两种不同的配置进行基准测试。
Triton 内核配置使用:
- Triton SplitK GEMM
- AMD Triton Flash Attention
CUDA 内核配置使用:
- cuBLAS GEMM
- cuDNN Flash Attention – 缩放点积注意力 (SDPA)
在典型的推理设置下,我们针对即时模式和编译模式发现了以下吞吐量和 Token 间延迟:
| GPU | 模型 | 内核配置 | 中位延迟 (即时) [ms/tok] | 中位延迟 (编译) [ms/tok] |
|---|---|---|---|---|
| H100 | Granite-8B | Triton | 27.42 | 11.59 |
| CUDA | 18.84 | 9.50 | ||
| Llama3-8B | Triton | 20.36 | 10.61 | |
| CUDA | 16.59 | 8.59 | ||
| A100 | Granite-8B | Triton | 53.44 | 16.88 |
| CUDA | 37.13 | 14.25 | ||
| Llama3-8B | Triton | 44.44 | 17.94 | |
| CUDA | 32.45 | 12.96 |
图 7. Granite-8B 和 Llama3-8B 在 H100 和 A100 上的单 Token 生成延迟
(批次大小 = 2,输入序列长度 = 512,输出序列长度 = 256)
总结来说,在 H100 上,Triton 模型最高可达到 CUDA 模型 78% 的性能,在 A100 上最高可达到 82%。
性能差距可以通过我们在下一节讨论的 matmul 和 Flash Attention 内核延迟来解释。
8.0 微基准测试
| 内核 | Triton [us] | CUDA [us] |
|---|---|---|
| QKV 投影 Matmul | 25 | 21 |
| Flash Attention | 13 | 8 |
| 输出投影 Matmul | 21 | 17 |
| 门控 + 上投影 Matmul | 84 | 83 |
| 下投影 Matmul | 58 | 42 |
图 8. Triton 和 CUDA 内核延迟对比 (Llama3-8B 在 NVIDIA H100 上)
输入为随机 Prompt(bs=1,Prompt = 44 序列长度),解码延迟时间
基于以上数据,我们得出以下结论:
- Triton matmul 内核比 CUDA 慢 1.2-1.4 倍
- AMD 的 Triton Flash Attention 内核比 CUDA SDPA 慢 1.6 倍
这些结果凸显了进一步提高核心原语(如 GEMM 和 Flash Attention)性能的必要性。我们将此留作未来的研究,因为最近的作品(例如 FlashAttention-3, FlexAttention)提供了更好地利用底层硬件的方法,以及我们希望能在其基础上构建以实现更大加速的 Triton 路径。为了说明这一点,我们将 FlexAttention 与 SDPA 和 AMD 的 Triton Flash 内核进行了比较。
我们正在努力验证 FlexAttention 的端到端性能。目前,Flex 的初步微基准测试在较长上下文长度和解码问题形状(Query 向量较小时)中展现了潜力。

图 9. NVIDIA H100 SXM5 80GB 上的 FlexAttention 内核基准测试
(批次=1, 注意力头数=32, 序列长度=seq_len, 头维度=128)
9.0 未来工作
未来,我们计划探索更好地利用硬件优化 matmul 的方法,例如我们发布的关于利用 H100 TMA 单元的博客,以及不同的工作分解(如 StreamK 等持久化内核技术),以使我们的 Triton 方法获得更大的加速。对于 Flash Attention,我们计划探索 FlexAttention 和 FlashAttention-3,因为这些内核中使用的技术可以被借鉴,以帮助进一步缩小 Triton 与 CUDA 之间的性能差距。
我们还注意到,我们之前的工作已展示了 FP8 Triton GEMM 内核相较于 cuBLAS FP8 GEMM 的出色性能,因此在未来的文章中,我们将探索端到端的 FP8 大模型推理。