作者:IBM 研究院:Antoni Viros i Martin, Brian Vaughan, Davis Wertheimer, Joshua Rosenkranz, Mudhakar Srivatsa, Nelson Mimura Gonzalez, Raghu Ganti, Supriyo Chakraborty, Zhuoran Liu Meta:Geeta Chauhan, Hamid Shojanazeri

在这篇博客中,我们讨论了如何使用 PyTorch 的原生优化技术来改进 Llama 2 系列模型的推理延迟,这些技术包括原生快速内核、torch compile 的编译转换以及用于分布式推理的张量并行。我们的方法在 70B LLaMa 模型上实现了单用户请求的 29ms/token 延迟(在 8 个 A100 GPU 上测量)。我们很高兴与社区分享我们的发现,并将代码公开在此处

背景

我们正处于生成式 AI 革命之中,拥有数百亿参数的大型语言模型正日益普及并可供使用。然而,在社区中公认的是,以成本高效的方式部署这些大型模型仍然是一个关键挑战。人们尝试了许多不同的方法,取得了不同程度的成功,并提供了不同的权衡。硬件特定的优化(例如,NVIDIA 的 Faster Transformer)仅限于特定的目标硬件,而依赖抽象层(例如,ONNX)的方法虽然支持任意模型,但效率会受到损失。随着去年 PyTorch compile 的推出,IBM 和 PyTorch 团队开始探索使用模型编译来进行推理优化,目标是降低生成模型的每 token 延迟。

模型选择

考虑到它们的受欢迎程度,我们选择在 Llama 2 系列模型上进行基准测试。我们感兴趣的模型及其与本文相关的超参数如下表所示

模型大小 隐藏层维度 头数 层数 注意力类型
7B 4096 32 32 MHA
13B 5120 40 40 MHA
70B 8192 64 80 GQA

这些模型是仅解码器模型,这意味着 token 是串行生成的,这通常通过 KV 缓存来加速。我们在延迟和吞吐量测量中采用了类似的方法。

推理方法

我们的推理目标是提供一条快速实现最佳延迟的路径,以跟上社区中新模型架构涌现的速度。PyTorch 原生方法很有吸引力,因为它在模型“覆盖范围”方面提供了最大的灵活性。我们注意到有四种正交技术可以加速推理:(a) 使用 compile 进行内核融合,(b) 更快的内核,(c) 用于大型模型的张量并行,以及 (d) 量化。在我们的方法中,我们使用了这四个杠杆中的前三个——compile 原生地与 SDPA 的更快内核以及一个定制的张量并行实现协同工作,在单用户使用 8 个 NVIDIA A100 GPU 测量时,在 70B 模型上实现了 29ms/token 的推理延迟。

全程编译!

PyTorch Compile 利用 tracing(跟踪)和 graph capture(图捕获)来减少 CPU 开销,在理想情况下,可以实现 CPU 到 GPU 的单次图执行/指令。然而,由于模型架构和 compile 不支持的操作,compile 经常会引入 graph breaks(图中断)。例如,诸如 einops 之类的复杂操作目前不受 compile 支持。类似地,张量并行推理可以在每一层引入 graph breaks,因为 compile 要求张量并行实现使用可跟踪的通信集合操作(collectives)。如果这些 graph breaks 没有被移除,编译后的 artifact(产物)的性能将会受到影响,甚至可能低于 eager mode(即时模式)执行。为了充分发挥编译后 artifact 的优势,需要移除 graph breaks。

下面,我们描述了如何对 70b Llama 2 模型进行此操作,以及为了使 compile 全程工作我们必须克服的挑战。

我们的第一次尝试是使用 torch.compile 编译开箱即用的 Llama 2 模型,但由于不支持复杂操作而失败。通过设置 TORCH_COMPILE_DEBUG = 1,我们发现 RoPE 位置编码使用了复数函数,导致了 graph breaks 和显著的速度下降。我们重写了 RoPE 函数,绕过了 torch.einsum(原始实现使用了 torch.polar,这也与 compile 冲突),而是使用了 torch.cos 和 torch.sin。

self.cached_freqs[dev_idx][alpha] = torch.stack(
            [
                torch.cos(freqs),
                -torch.sin(freqs),
                torch.sin(freqs),
                torch.cos(freqs),
            ],
            dim=2,
        ).view(*freqs.shape, 2, 2)

我们实现的频率计算

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

Hugging Face 实现的频率计算

RoPE 修复后,我们能够在单个 A100 GPU 上编译 7B 和 13B 模型,没有任何 graph breaks。

我们使用了 SDPA,这是 PyTorch 原生的搞笑注意力计算实现,并且启用了 tracing(用于 compile)。为了避免因使用 Python context 强制选择单一算法而导致的 graph breaks(这是推荐的方法),我们不得不使用 torch.backends.cuda.enable_*_sdp 函数。

attn = torch.nn.functional.scaled_dot_product_attention(
            queries,
            keys_e,
            values_e,
            attn_mask=attn_mask,
            dropout_p=self.p_dropout if self.training else 0.0,
            is_causal=is_causal_mask,
)

使用 SDPA 进行注意力计算

接下来,我们对更大的 70B 模型执行了相同的步骤,发现即使使用半精度,模型也无法适应单个 GPU,需要进行张量并行推理。对 70B 模型使用 torch.compile 导致了 162 个 graph breaks,原因是每层有两个 all-reduce,前向 embedding 有一个 all-gather,后向 embedding 有一个 all-gather。因此,我们没有看到推理延迟有任何显著改进。在撰写本文时,我们无法使用 PyTorch 的分布式张量实现,因为它不支持 compile。我们从头重写了张量并行代码,使其仅依赖于可跟踪的集合操作(collectives),从而使其能够与 compile 配合使用。经过最后这次修改后,PyTorch 编译器没有引入任何 graph breaks,我们看到了推理延迟的显著加速。具体来说,我们在使用 8 个 A100 GPU 时测得 Llama 70B 模型的延迟为 29ms/token,比未优化的推理提高了 2.4 倍。

服务方面

最后,这里需要注意的是,简单地对模型进行 compile 不足以在生产环境中提供服务。为了以高吞吐量实现上述性能,我们需要支持动态批处理(dynamic batching)、嵌套张量(nested tensors),以及一个预热阶段,在该阶段我们对分桶的序列长度(bucketized sequence lengths)进行预编译。我们正在研究这些方面,以便在生产环境中实现此类性能。

实验和测量

我们在两个不同的环境(IBM Cloud 和 AWS,均运行 OpenShift)中使用配备 8 个 80G A100 NVIDIA GPU 的节点进行所有测量。首先,我们比较了各种技术——eager mode(即时模式)、使用 SDPA Flash 内核、使用 Compile 以及使用 Compile 和 SDPA。对于 70B 模型,我们在使用 compile 和 SDPA 的情况下以张量并行模式运行。在这个实验中,我们使用 512 个 token 作为输入长度,生成 50 个 token。对于 7B 和 13B 模型,我们使用单个 A100 进行延迟测量,而对于 70B 模型,我们使用 8 个 A100。此外,对于 70B 模型,我们在 PyTorch compile 中使用了 reduce-overhead 选项,该选项使用 CudaGraphs 来减少 CPU 到 GPU 内核启动的开销;在 7B 和 13B 模型中使用 CudaGraphs 没有显示出任何益处(因此在此不报告)。从图 1 可以看出,compile 和 SDPA 提供了非常低的延迟,70B Llama 2 模型的延迟为 29ms/token。

Figure 1. Median latency across different techniques with sequence length 512 (measured on IBM Cloud A100 servers)

图 1:不同技术在序列长度为 512 时的中位数延迟(在 IBM Cloud A100 服务器上测量)

接下来,我们检查序列长度的影响,我们将序列长度从 1024 增加到 4096,观察到每 token 的中位数延迟呈亚线性增长,这表明当我们增加处理大型文档的上下文时,不会牺牲响应时间。

Figure 2. Median latency for compile+SDPA with different sequence lengths (Measured on A100s on AWS)

图 2:compile+SDPA 在不同序列长度时的中位数延迟(在 AWS 的 A100 上测量)

最后,随着批处理大小的增加,我们观察到响应延迟呈亚线性增长。对于 13B 模型,在批处理大小为 8 时,我们遇到了 OOM(内存不足)错误。对于 70B 模型,考虑到它在 8 个 GPU 上以张量并行模式运行,我们没有看到任何此类 OOM 问题。

Figure 3. Median latency for compile+SDPA with different batch sizes and sequence length fixed at 4096 (Measured on A100s on AWS)

图 3:compile+SDPA 在不同批处理大小和序列长度固定为 4096 时的中位数延迟(在 AWS 的 A100 上测量)

总结

我们演示了 PyTorch 的 compile 路径如何为 70B 模型推理带来超低延迟。下一步是结合上述杠杆,实现动态批处理和嵌套张量。

特别感谢来自 PyTorch 团队的 Edward Yang、Elias Ellison、Driss Guessous、Will Feng、Will Constable、Horace He、Less Wright 和 Andrew Gu,他们的 PR 评审和代码贡献使我们能够使用 PyTorch 原生方法实现如此低的延迟。我们感谢更广泛的 PyTorch 团队,他们一直在不懈努力使 PyTorch 变得更好,特别要感谢 SDPA 团队在快速内核上启用 tracing 和 compile,以及 compile 团队,他们一直密切指导我们如何解决问题以及修复问题(包括识别并提出 NVIDIA 驱动在 CUDA graphs 中的错误)。

推理延迟一直是 LLM 在关键企业工作流程中被采用的阻碍之一,而另一个主要阻碍是对安全性、可信度和治理的需求。IBM 关于 AI 安全和 LLM 风险的指南可在此处找到,Meta 的 LLaMa 负责任用户指南可在此处找到。

参考资料