作者:IBM 研究院:Antoni Viros i Martin, Brian Vaughan, Davis Wertheimer, Joshua Rosenkranz, Mudhakar Srivatsa, Nelson Mimura Gonzalez, Raghu Ganti, Supriyo Chakraborty, Zhuoran Liu Meta:Geeta Chauhan, Hamid Shojanazeri

在这篇博客中,我们讨论了如何使用 PyTorch 原生优化(例如原生快速内核、来自 torch compile 的编译转换和用于分布式推理的张量并行)来改善 Llama 2 模型系列的推理延迟。我们的方法在 70B LLaMa 模型上(在 8 个 A100 GPU 上测量)实现了单用户请求 29 毫秒/token 的延迟。我们很高兴与社区分享我们的发现,并在此处提供我们的代码:here

背景

我们正处于生成式 AI 革命之中,拥有数百亿参数的大型语言模型变得商品化并可供使用。然而,社区普遍认识到,以经济高效的方式部署这些大型模型仍然是一个关键挑战。已经尝试了许多不同的方法,它们在不同程度上取得了成功,并提供了不同的权衡。特定于硬件的优化(例如,NVIDIA 的 Faster Transformer)仅限于特定的目标硬件,而依赖于抽象层的方法(例如,ONNX)可以支持任意模型,但会降低效率。随着去年 PyTorch compile 的推出,IBM 和 PyTorch 团队开始探索使用模型编译进行推理优化,目标是减少生成模型的每个 token 的延迟。

模型选择

鉴于 Llama 2 模型系列的受欢迎程度,我们选择以其为基准进行评测。我们感兴趣的模型以及与本博客相关的超参数在下表中给出

模型大小 隐藏维度 头数 层数 注意力类型
7B 4096 32 32 MHA
13B 5120 40 40 MHA
70B 8192 64 80 GQA

这些模型仅为解码器模型,这意味着 token 以串行方式生成,这通常使用 KV 缓存来加速。我们在延迟和吞吐量测量中采用了类似的方法。

推理方法

我们的推理目标是提供一条快速实现最佳延迟的路径,以跟上社区中新模型架构涌现的速度。PyTorch 原生方法很有吸引力,因为它在模型“覆盖范围”方面具有最大的灵活性。我们注意到,有四种正交技术可以加速推理:(a)使用编译进行内核融合,(b)更快的内核,(c)用于更大模型的张量并行,以及(d)量化。在我们的方法中,我们使用了这四个杠杆中的前三个 - 原生编译与 SDPA 中的更快内核以及自定义张量并行实现协同工作,以在 70B 模型上实现 29 毫秒/token 的推理延迟,这是在 8 个 NVIDIA A100 GPU 上针对单用户测量的结果。

一路编译到底!

PyTorch Compile 利用跟踪和图捕获来减少 CPU 开销,在理想情况下,可以实现从 CPU 到 GPU 的单个图执行/指令。然而,由于模型架构和 compile 不支持的操作,编译通常会引入图中断。例如,今天 compile 不支持诸如 einops 之类的复杂操作。同样,张量并行推理可能会在每一层引入图中断,因为 compile 要求张量并行实现使用可跟踪的通信集合。如果这些图中断没有被消除,编译后的工件的性能将受到影响,甚至可能低于 eager 模式执行。为了充分利用编译后的工件,需要消除图中断。

下面,我们将描述我们如何为 70b Llama 2 模型做到这一点,以及我们必须克服哪些挑战才能使编译一直工作。

我们的第一次尝试是尝试使用 torch.compile 编译开箱即用的 Llama 2 模型,但由于不支持复杂操作而失败。使用 TORCH_COMPILE_DEBUG = 1,我们确定 RoPE 位置编码正在使用复数函数,导致图中断和明显的减速。我们重写了 RoPE 函数以绕过 torch.einsum(原始实现使用 torch.polar,这也与编译冲突),并改为使用 torch.cos 和 torch.sin。

self.cached_freqs[dev_idx][alpha] = torch.stack(
            [
                torch.cos(freqs),
                -torch.sin(freqs),
                torch.sin(freqs),
                torch.cos(freqs),
            ],
            dim=2,
        ).view(*freqs.shape, 2, 2)

我们频率计算的实现

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

Hugging Face 频率计算的实现

一旦 RoPE 被修复,我们就能够在单个 A100 GPU 上编译 7B 和 13B 模型,而没有任何图中断。

我们使用了 SDPA,这是 PyTorch 原生实现的高效注意力计算,并启用了跟踪(用于编译)。为了避免由于使用 Python 上下文强制选择单个算法而导致的图中断(推荐的方法),我们必须使用 torch.backends.cuda.enable_*_sdp 函数。

attn = torch.nn.functional.scaled_dot_product_attention(
            queries,
            keys_e,
            values_e,
            attn_mask=attn_mask,
            dropout_p=self.p_dropout if self.training else 0.0,
            is_causal=is_causal_mask,
)

使用 SDPA 的注意力计算

接下来,我们对更大的 70B 模型运行了相同的步骤,发现即使使用半精度,该模型也无法容纳在单个 GPU 中,并且需要张量并行推理。对 70B 模型使用 torch.compile 导致 162 个图中断,这是由于每层两个 all-reduce、前向嵌入一个 all-gather 和反向嵌入一个 all-gather 造成的。因此,我们没有看到推理延迟的显着改善。在撰写本博客时,我们无法使用 PyTorch 的分布式张量实现,因为它不支持编译。我们从头开始重写了张量并行代码,使其仅依赖于可跟踪的集合通信,以便与编译一起工作。在最后一次更改之后,PyTorch 编译器没有引入任何图中断,我们看到了推理延迟的显着加速。具体而言,当使用 8 个 A100 GPU 时,我们测得 Llama 70B 模型的延迟为 29 毫秒/token,比未优化推理提高了 2.4 倍。

服务方面

最后,这里需要注意的一点是,仅仅对模型执行编译不足以在生产环境中服务该模型。为了以高吞吐量实现上述性能,我们需要支持动态批处理、嵌套张量,以及一个预热阶段,我们在其中为分桶序列长度进行预编译。我们正在努力解决这些方面,以在生产环境中实现这种性能。

实验和测量

我们在两个不同的环境(IBM Cloud 和 AWS,均运行 OpenShift)中使用配备 8 个 A100 NVIDIA GPU 和 80G 显卡的节点进行所有测量。首先,我们比较各种技术——eager 模式、使用 SDPA Flash 内核、使用 Compile 以及使用 Compile 和 SDPA。对于 70B 模型,我们在张量并行模式下使用 compile 和 SDPA 运行它。对于此实验,我们使用 512 个 token 作为输入长度,并生成 50 个 token。对于 7B 和 13B 模型,我们使用单个 A100 来测量延迟,而对于 70B 模型,我们使用 8 个 A100。此外,对于 70B 模型,我们使用 PyTorch compile 中的 reduce-overhead 选项,该选项使用 CudaGraphs 来减少 CPU 到 GPU 内核启动开销;在 7B 和 13B 模型中使用 CudaGraphs 没有显示任何好处(因此此处未报告)。我们从图 1 中观察到,compile 和 SDPA 提供了非常低的延迟,70B Llama 2 模型的延迟为 29 毫秒/token。

Figure 1. Median latency across different techniques with sequence length 512 (measured on IBM Cloud A100 servers)

图 1:不同技术在序列长度为 512 时的中值延迟(在 IBM Cloud A100 服务器上测量)

接下来,我们检查序列长度的影响,我们将其从 1024 增加到 4096,并观察到每个 token 的中值延迟呈次线性增长,这表明当我们将上下文增加到大型文档时,我们不会牺牲响应时间。

Figure 2. Median latency for compile+SDPA with different sequence lengths (Measured on A100s on AWS)

图 2:不同序列长度下 compile+SDPA 的中值延迟(在 AWS 上的 A100 上测量)

最后,随着批大小的增加,我们观察到响应延迟呈次线性增长。对于 13B 模型,在批大小为 8 时,我们遇到 OOM 错误。对于 70B 模型,鉴于它在 8 个 GPU 上以张量并行方式运行,我们没有看到任何此类 OOM 问题。

Figure 3. Median latency for compile+SDPA with different batch sizes and sequence length fixed at 4096 (Measured on A100s on AWS)

图 3:不同批大小下 compile+SDPA 的中值延迟,序列长度固定为 4096(在 AWS 上的 A100 上测量)

最终想法

我们已经演示了 PyTorch 编译推理路径如何为 70B 模型推理展示超低延迟。下一步是使用上述杠杆启用动态批处理和嵌套张量。

特别感谢 PyTorch 团队的 Edward Yang、Elias Ellison、Driss Guessous、Will Feng、Will Constable、Horace He、Less Wright 和 Andrew Gu,他们的 PR 审查和代码贡献使我们能够使用 PyTorch 原生方法实现延迟。我们感谢更广泛的 PyTorch 团队,他们一直在不懈努力使 PyTorch 变得更好,特别感谢 SDPA 团队在快速内核上启用跟踪和编译,以及编译团队,他们一直在密切指导我们如何解决和修复问题(包括识别和提出 CUDA 图中的 NVIDIA 驱动程序错误)。

推理延迟一直是 LLM 在关键企业工作流程中采用的障碍之一,但另一个主要障碍是对安全性、可信赖性和治理的需求。IBM 的 AI 安全和 LLM 风险指南可在此处找到,Meta 的 LLaMa 负责任用户指南可在此处找到:here and here

参考文献