跳转到主要内容

在本博客中,我们将讨论如何使用 PyTorch 原生优化(例如原生快速内核、torch compile 的编译转换以及分布式推理的张量并行)来改善 Llama 2 系列模型的推理延迟。我们的方法使得 70B LLaMa 模型在单用户请求下实现 29 毫秒/token 的延迟(在 8 块 A100 GPU 上测量)。我们很高兴与社区分享我们的发现,并在此提供我们的代码。

背景

我们正处于生成式 AI 革命之中,拥有数百亿参数的大型语言模型正在商品化并可供使用。然而,社区普遍认为,以成本高效的方式部署这些大型模型仍然是一个关键挑战。人们尝试了许多不同的方法,取得了不同程度的成功并提供了不同的权衡。特定硬件优化(例如 NVIDIA 的 Faster Transformer)仅限于特定的目标硬件,而依赖于抽象层的方法(例如 ONNX)虽然支持任意模型,但效率低下。随着去年 PyTorch compile 的推出,IBM 和 PyTorch 团队开始探索使用模型编译进行推理优化,旨在降低生成模型的每 token 延迟。

模型选择

鉴于 Llama 2 系列模型的受欢迎程度,我们选择对其进行基准测试。我们感兴趣的模型及其与本博客相关的超参数如下表所示:

模型大小隐藏维度头数量层数量注意力类型
7B40963232MHA
13B51204040MHA
70B81926480GQA

这些模型是仅解码器模型,这意味着 token 以序列化方式生成,通常使用 KV 缓存来加速。我们在延迟和吞吐量测量中也采用了类似的方法。

推理方法

我们进行推理的目标是提供一条快速实现最佳延迟的途径,以跟上社区中新模型架构出现的快速步伐。PyTorch 原生方法具有吸引力,因为它在模型“覆盖范围”方面提供了最大的灵活性。我们注意到有四种正交技术可以加速推理:(a) 使用编译进行内核融合,(b) 更快的内核,(c) 用于更大模型的张量并行,以及 (d) 量化。在我们的方法中,我们使用了这四种杠杆中的前三种——编译与 SDPA 的更快内核原生协作,以及一个定制的张量并行实现,它们协同工作,在 8 块 NVIDIA A100 GPU 上,单用户情况下,70B 模型的推理延迟达到 29 毫秒/token。

全程编译!

PyTorch Compile 利用跟踪和图捕获来减少 CPU 开销,在理想情况下,CPU 到 GPU 的单个图执行/指令即可完成。然而,编译通常会因为模型架构和编译不支持的操作而引入图中断。例如,目前编译不支持 einops 等复杂操作。同样,张量并行推理可能会在每一层引入图中断,因为编译要求张量并行实现使用可跟踪的通信集合操作。如果这些图中断不消除,编译后的工件性能将受到影响,甚至可能低于 eager 模式执行。为了充分利用编译后的工件,需要消除图中断。

下面,我们将描述我们如何为 70b Llama 2 模型完成此操作,以及为使编译全程工作而必须克服的挑战。

我们的第一次尝试是使用 torch.compile 编译开箱即用的 Llama 2 模型,但由于不支持复杂操作而失败。使用 TORCH_COMPILE_DEBUG = 1,我们发现 RoPE 位置编码使用了复数函数,导致图中断和显著的性能下降。我们重写了 RoPE 函数以绕过 torch.einsum(原始实现使用了 torch.polar,这与编译也冲突),并改用 torch.cos 和 torch.sin。

self.cached_freqs[dev_idx][alpha] = torch.stack(
            [
                torch.cos(freqs),
                -torch.sin(freqs),
                torch.sin(freqs),
                torch.cos(freqs),
            ],
            dim=2,
        ).view(*freqs.shape, 2, 2)

我们实现的频率计算

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

Hugging Face 实现的频率计算

RoPE 修复后,我们可以在单块 A100 GPU 上编译 7B 和 13B 模型,没有任何图中断。

我们使用了 SDPA,这是 PyTorch 原生的高效注意力计算实现,并启用了跟踪(用于编译)。为了避免与使用 Python 上下文强制选择单一算法(推荐方式)相关的图中断,我们不得不使用 torch.backends.cuda.enable_*_sdp 函数。

attn = torch.nn.functional.scaled_dot_product_attention(
            queries,
            keys_e,
            values_e,
            attn_mask=attn_mask,
            dropout_p=self.p_dropout if self.training else 0.0,
            is_causal=is_causal_mask,
)

使用 SDPA 的注意力计算

接下来,我们对更大的 70B 模型执行了相同的步骤,发现即使使用半精度,模型也无法在单个 GPU 中容纳,需要张量并行推理。对 70B 模型使用 torch.compile 导致了 162 个图中断,原因是在每层有两个 all-reduce,一个用于正向嵌入的 all-gather,以及一个用于反向嵌入的 all-gather。因此,我们没有看到推理延迟有显著改善。在撰写本博客时,我们无法使用 PyTorch 的分布式张量实现,因为它不支持编译。我们从头重写了张量并行代码,使其仅依赖于可跟踪的集合操作,以使其与编译兼容。在进行最后一次更改后,PyTorch 编译器没有引入任何图中断,我们看到了推理延迟的显著加速。具体来说,我们测量了 Llama 70B 模型在使用 8 块 A100 GPU 时的延迟为 29 毫秒/token,比未优化推理提高了 2.4 倍。

服务方面

最后,这里需要指出的是,仅仅对模型进行编译不足以在生产环境中提供模型服务。为了在高吞吐量下实现上述性能,我们需要支持动态批处理、嵌套张量,并有一个预热阶段,我们预先编译分桶序列长度。我们正在努力解决这些方面,以在生产环境中实现这种性能。

实验和测量

我们在两种不同环境(IBM Cloud 和 AWS,均运行 OpenShift)中使用配备 8 块 80G A100 NVIDIA GPU 的节点进行所有测量。首先,我们比较了各种技术—— eager 模式、SDPA Flash 内核、编译以及编译和 SDPA。对于 70B 模型,我们以张量并行模式运行它,并使用编译和 SDPA。对于此实验,我们使用 512 个 token 作为输入长度,生成 50 个 token。对于 7B 和 13B 模型,我们使用单个 A100 来测量延迟,而对于 70B 模型,我们使用 8 个 A100。此外,对于 70B 模型,我们在 PyTorch compile 中使用 reduce-overhead 选项,该选项使用 CudaGraphs 来减少 CPU 到 GPU 内核启动开销;在 7B 和 13B 模型中使用 CudaGraphs 未显示任何好处(因此未在此处报告)。我们从图 1 中观察到,编译和 SDPA 提供了非常低的延迟,其中 70B Llama 2 模型的延迟为 29 毫秒/token。

Figure 1. Median latency across different techniques with sequence length 512 (measured on IBM Cloud A100 servers)

图 1:不同技术下的中位延迟,序列长度为 512(在 IBM Cloud A100 服务器上测量)

接下来,我们研究序列长度的影响,将其从 1024 增加到 4096,观察到每 token 的中位延迟呈次线性增加,这表明当我们将上下文增加到大型文档时,我们不会牺牲响应时间。

Figure 2. Median latency for compile+SDPA with different sequence lengths (Measured on A100s on AWS)

图 2:编译+SDPA 在不同序列长度下的中位延迟(在 AWS 的 A100 上测量)

最后,随着批量大小的增加,我们观察到响应延迟呈次线性增加。对于 13B 模型,当批量大小为 8 时,我们遇到了 OOM(内存不足)。对于 70B 模型,鉴于它在 8 个 GPU 上以张量并行方式运行,我们没有看到任何此类 OOM 问题。

Figure 3. Median latency for compile+SDPA with different batch sizes and sequence length fixed at 4096 (Measured on A100s on AWS)

图 3:编译+SDPA 在不同批量大小和序列长度固定为 4096 下的中位延迟(在 AWS 的 A100 上测量)

最终想法

我们已经展示了 PyTorch 编译推理路径如何为 70B 模型推理带来超低延迟。下一步是利用上述杠杆实现动态批处理和嵌套张量。

特别感谢 PyTorch 团队的 Edward Yang、Elias Ellison、Driss Guessous、Will Feng、Will Constable、Horace He、Less Wright 和 Andrew Gu,他们的 PR 审查和代码贡献使我们能够使用 PyTorch 原生方法实现这些延迟。我们感谢更广泛的 PyTorch 团队,他们不知疲倦地致力于改进 PyTorch,特别鸣谢 SDPA 团队在快速内核上启用跟踪和编译,以及编译团队密切指导我们如何解决和修复问题(包括识别和提出 CUDA 图中的 NVIDIA 驱动程序错误)。

推理延迟一直是 LLM 在关键企业工作流程中普及的障碍之一,但另一个主要障碍是安全、可信度和治理的需求。IBM 的 AI 安全和 LLM 风险指南可在此处找到,Meta 的 LLaMa 负责任使用指南可在此处找到

参考文献