跳转到主要内容
博客

无需 CUDA 的 LLM 推理

作者: 2024 年 9 月 4 日2024 年 11 月 11 日暂无评论

在本博客中,我们讨论了我们用来实现流行 LLM 模型(例如 Meta 的 Llama3-8BIBM 的 Granite-8B Code)FP16 推理的方法,其中 100% 的计算使用 OpenAI 的 Triton 语言执行。
对于使用基于 Triton 内核的模型进行的单 token 生成时间,我们能够达到与 CUDA 内核主导的工作流在 Nvidia H100 GPU 上的 Llama 和 Granite 上 0.76-0.78 倍的性能,以及在 Nvidia A100 GPU 上的 0.62-0.82 倍的性能。

为什么要探索使用 100% Triton?Triton 为 LLM 在不同类型的 GPU(NVIDIA、AMD,以及未来 Intel 和其他基于 GPU 的加速器)上运行提供了途径。它还提供了一个更高级的 Python 抽象层来编程 GPU,并使我们能够比使用供应商特定 API 更快地编写高性能内核。在本博客的其余部分,我们将分享如何实现 CUDA-free 计算,对单个内核进行微基准测试以进行比较,并讨论如何进一步改进未来的 Triton 内核以缩小差距。

图 1. 使用 Triton 和 CUDA 变体的 Llama3-8B 和 Granite-8B 在 NVIDIA H100 和 A100 上的推理吞吐量基准测试
设置:批大小 = 2,输入序列长度 = 512,输出序列长度 = 256

2.0 Transformer 块的组成

我们首先分解 Transformer 模型中发生的计算。下图显示了典型 Transformer 块的“内核”。

 图 2. 按核心内核划分的 Transformer 块

Llama3 架构的核心操作总结如下:

  1. RMSNorm
  2. 矩阵乘法:融合 QKV
  3. RoPE
  4. Attention
  5. 矩阵乘法:输出投影
  6. RMSNorm
  7. 矩阵乘法:融合门控 + 上采样投影
  8. 激活函数:SiLU
  9. 逐元素乘法
  10. 矩阵乘法:下采样投影

每个操作都在 GPU 上通过执行一个(或多个)内核进行计算。虽然这些内核的具体细节在不同的 transformer 模型中可能有所不同,但核心操作保持不变。例如,IBM 的 Granite 8B Code 模型在 MLP 层中使用偏置,这与 Llama3 不同。这些更改确实需要修改内核。典型的模型是这些 transformer 块与嵌入层堆叠在一起。

3.0 模型推理

典型的模型架构代码通过 PyTorch 启动的 python model.py 文件共享。在默认的 PyTorch 即时执行模式下,这些内核都使用 CUDA 执行。为了实现 Llama3-8B 和 Granite-8B 端到端推理的 100% Triton,我们需要编写和集成手写 Triton 内核,并利用 torch.compile(生成 Triton 操作)。首先,我们用编译器生成的 Triton 内核替换较小的操作,其次,我们用手写 Triton 内核替换更昂贵和复杂的计算(例如矩阵乘法和 Flash Attention)。

Torch.compile 自动为 RMSNorm、RoPE、SiLU 和逐元素乘法生成 Triton 内核。使用 Nsight Systems 等工具,我们可以观察到这些生成的内核;它们以微小的深绿色内核形式出现在矩阵乘法和 Attention 之间。

 图 3. Llama3-8B 使用 torch.compile 的跟踪,显示 CUDA 内核用于矩阵乘法和 Flash Attention

对于上述跟踪,我们注意到在 Llama3-8B 风格模型中构成 E2E 延迟 80% 的两个主要操作是矩阵乘法和 attention 内核,两者都仍然是 CUDA 内核。因此,为了弥补剩余的差距,我们用手写 Triton 内核替换了 matmul 和 attention 内核。

4.0 Triton SplitK GEMM 内核

对于线性层中的矩阵乘法,我们编写了一个自定义的 FP16 Triton GEMM(通用矩阵乘法)内核,该内核利用了 SplitK 工作分解。我们之前在其他博客中讨论过这种并行化,作为加速 LLM 推理解码部分的一种方式。

5.0 GEMM 内核调优

为了实现最佳性能,我们使用了穷尽搜索方法来调优我们的 SplitK GEMM 内核。Granite-8B 和 Llama3-8B 具有以下形状的线性层:

线性层形状 (in_features, out_features)
融合 QKV 投影(4096, 6144)
输出投影(4096, 4096)
融合门控 + 上投影(4096, 28672)
下投影(14336, 4096)

图 4. Granite-8B 和 Llama3-8B 线性层权重矩阵形状

这些线性层中的每一个都具有不同的权重矩阵形状。因此,为了获得最佳性能,必须针对每个形状配置文件对 Triton 内核进行调优。在对每个线性层进行调优后,我们能够在 Llama3-8B 和 Granite-8B 上实现比未调优的 Triton 内核 1.20 倍的 E2E 加速。

6.0 Flash Attention 内核

我们评估了一套现有 Triton Flash Attention 内核,它们具有不同的配置,包括:

  1. AMD Flash
  2. OpenAI Flash
  3. Dao AI Lab Flash
  4. XFormers Flash
  5. PyTorch FlexAttention

我们评估了这些内核的文本生成质量,首先在 Eager 模式下,然后(如果我们可以用标准方法 torch.compile 内核)在编译模式下。对于内核 2-5,我们注意到以下几点:

内核文本生成质量Torch.compile支持任意序列长度
AMD Flash连贯
OpenAI Flash不连贯未评估。正在调试 Eager 模式下的精度问题
Dao AI Lab Flash不连贯未评估。正在调试 Eager 模式下的精度问题
Xformers FlashDecoding在评估文本质量之前遇到编译错误进行中否(此内核针对解码进行了优化)
PyTorch FlexAttention连贯进行中进行中

图 5. 我们尝试与不同 Flash Attention 内核组合的表格

上表总结了我们开箱即用的观察结果。我们预计通过一些努力,内核 2-5 可以修改以满足上述标准。然而,这也表明,拥有一个可用于基准测试的内核通常只是使其可用作端到端生产内核的开始。
我们在后续测试中选择了 AMD Flash Attention 内核,因为它可以由 torch.compile 编译,并在 Eager 和编译模式下都能生成可读的输出。

为了满足 AMD Flash Attention 内核与 torch.compile 的兼容性,我们必须将其定义为 torch 自定义操作。此过程在此处详细解释 本教程链接讨论了如何封装简单的图像裁剪操作。但是,我们注意到封装更复杂的 Flash Attention 内核遵循类似的过程。两步方法如下:

  1. 将函数封装到 PyTorch 自定义操作中
  1. 向操作添加 FakeTensor 内核,该内核在给定 Flash (q, k 和 v) 输入张量形状的情况下,提供了一种计算 Flash 内核输出形状的方法

将 Triton Flash 内核定义为自定义操作后,我们成功地为我们的 E2E 运行对其进行了编译。

图 6. Llama3-8B 使用 torch.compile 的跟踪,在换入 Triton matmul 和 Triton Flash Attention 内核之后

从图 5 中,我们注意到,在集成了 SplitK 矩阵乘法内核、torch 操作封装的 Flash Attention 内核并运行 torch.compile 之后,我们能够实现一个使用 100% Triton 计算内核的前向传播。

7.0 端到端基准测试

我们使用 Granite-8B 和 Llama3-8B 模型在 NVIDIA H100 和 A100(单 GPU)上进行了端到端测量。我们使用两种不同的配置进行了基准测试。

Triton 内核配置使用

  1. Triton SplitK GEMM
  2. AMD Triton Flash Attention

CUDA 内核配置使用

  1. cuBLAS GEMM
  2. cuDNN Flash Attention – 缩放点积注意力 (SDPA)

我们发现在典型推理设置下,Eager 和 Torch 编译模式下的吞吐量和 token 间延迟如下:

GPU模型内核配置中位延迟(Eager)[毫秒/token]中位延迟(编译)[毫秒/token]
H100Granite-8BTriton27.4211.59
  CUDA18.849.50
 Llama3-8BTriton20.3610.61
  CUDA16.598.59
A100Granite-8BTriton53.4416.88
  CUDA37.1314.25
 Llama3-8BTriton44.4417.94
  CUDA32.4512.96

图 7. Granite-8B 和 Llama3-8B 在 H100 和 A100 上的单 Token 生成延迟,
(批大小 = 2,输入序列长度 = 512,输出序列长度 = 256)

总而言之,Triton 模型在 H100 上可以达到 CUDA 模型性能的 78%,在 A100 上可以达到 82%

性能差距可以用我们在下一节中讨论的 matmul 和 Flash Attention 的内核延迟来解释。

8.0 微基准测试

内核Triton [微秒]CUDA [微秒]
QKV 投影矩阵乘法2521
Flash Attention138
输出投影矩阵乘法2117
门控 + 上投影矩阵乘法8483
下投影矩阵乘法5842

图 8. Triton 和 CUDA 内核延迟比较(Llama3-8B 在 NVIDIA H100 上)
输入为任意提示(bs=1,提示 = 44 序列长度),解码延迟时间

从上图,我们注意到以下几点:

  1. Triton matmul 内核比 CUDA 慢 1.2-1.4 倍
  2. AMD 的 Triton Flash Attention 内核比 CUDA SDPA 慢 1.6 倍

这些结果强调了需要进一步提高 GEMM 和 Flash Attention 等核心原语内核的性能。我们将其留作未来的研究,因为最近的工作(例如 FlashAttention-3FlexAttention)提供了更好地利用底层硬件的方法以及 Triton 通道,我们希望能够在此基础上构建以实现更大的加速。为了说明这一点,我们将 FlexAttention 与 SDPA 和 AMD 的 Triton Flash 内核进行了比较。

我们正在努力验证 FlexAttention 的 E2E 性能。目前,Flex 的初步微基准测试显示,对于更长的上下文长度和解码问题形状(其中查询向量很小)有希望。

bar chart

图 9. FlexAttention 内核在 NVIDIA H100 SXM5 80GB 上的基准测试
(批处理=1, 头数=32, 序列长度=序列长度, 头维度=128)

9.0 未来工作

对于未来的工作,我们计划探索进一步优化 matmul 的方法,以更好地利用硬件,例如我们发表的关于 利用 TMA for H100 的博客,以及不同的工作分解(如 StreamK 等持久化内核技术)以获得基于 Triton 的方法更大的加速。对于 Flash Attention,我们计划探索 FlexAttention 和 FlashAttention-3,因为这些内核中使用的技术可以用来帮助进一步缩小 Triton 和 CUDA 之间的差距。
我们还注意到,我们之前的工作已经显示出 FP8 Triton GEMM 内核性能优于 cuBLAS FP8 GEMM 的可喜结果,因此在未来的帖子中,我们将探索 E2E FP8 LLM 推理。