在这篇博客中,我们讨论了如何使用 PyTorch 原生优化(例如原生快速内核、来自 torch compile 的编译转换和用于分布式推理的张量并行)来改善 Llama 2 模型系列的推理延迟。我们的方法在 70B LLaMa 模型上(在 8 个 A100 GPU 上测量)实现了单用户请求 29 毫秒/token 的延迟。我们很高兴与社区分享我们的发现,并在此处提供我们的代码:here。
背景
我们正处于生成式 AI 革命之中,拥有数百亿参数的大型语言模型变得商品化并可供使用。然而,社区普遍认识到,以经济高效的方式部署这些大型模型仍然是一个关键挑战。已经尝试了许多不同的方法,它们在不同程度上取得了成功,并提供了不同的权衡。特定于硬件的优化(例如,NVIDIA 的 Faster Transformer)仅限于特定的目标硬件,而依赖于抽象层的方法(例如,ONNX)可以支持任意模型,但会降低效率。随着去年 PyTorch compile 的推出,IBM 和 PyTorch 团队开始探索使用模型编译进行推理优化,目标是减少生成模型的每个 token 的延迟。
模型选择
鉴于 Llama 2 模型系列的受欢迎程度,我们选择以其为基准进行评测。我们感兴趣的模型以及与本博客相关的超参数在下表中给出
模型大小 | 隐藏维度 | 头数 | 层数 | 注意力类型 |
7B | 4096 | 32 | 32 | MHA |
13B | 5120 | 40 | 40 | MHA |
70B | 8192 | 64 | 80 | GQA |
这些模型仅为解码器模型,这意味着 token 以串行方式生成,这通常使用 KV 缓存来加速。我们在延迟和吞吐量测量中采用了类似的方法。
推理方法
我们的推理目标是提供一条快速实现最佳延迟的路径,以跟上社区中新模型架构涌现的速度。PyTorch 原生方法很有吸引力,因为它在模型“覆盖范围”方面具有最大的灵活性。我们注意到,有四种正交技术可以加速推理:(a)使用编译进行内核融合,(b)更快的内核,(c)用于更大模型的张量并行,以及(d)量化。在我们的方法中,我们使用了这四个杠杆中的前三个 - 原生编译与 SDPA 中的更快内核以及自定义张量并行实现协同工作,以在 70B 模型上实现 29 毫秒/token 的推理延迟,这是在 8 个 NVIDIA A100 GPU 上针对单用户测量的结果。
一路编译到底!
PyTorch Compile 利用跟踪和图捕获来减少 CPU 开销,在理想情况下,可以实现从 CPU 到 GPU 的单个图执行/指令。然而,由于模型架构和 compile 不支持的操作,编译通常会引入图中断。例如,今天 compile 不支持诸如 einops 之类的复杂操作。同样,张量并行推理可能会在每一层引入图中断,因为 compile 要求张量并行实现使用可跟踪的通信集合。如果这些图中断没有被消除,编译后的工件的性能将受到影响,甚至可能低于 eager 模式执行。为了充分利用编译后的工件,需要消除图中断。
下面,我们将描述我们如何为 70b Llama 2 模型做到这一点,以及我们必须克服哪些挑战才能使编译一直工作。
我们的第一次尝试是尝试使用 torch.compile 编译开箱即用的 Llama 2 模型,但由于不支持复杂操作而失败。使用 TORCH_COMPILE_DEBUG = 1,我们确定 RoPE 位置编码正在使用复数函数,导致图中断和明显的减速。我们重写了 RoPE 函数以绕过 torch.einsum(原始实现使用 torch.polar,这也与编译冲突),并改为使用 torch.cos 和 torch.sin。
self.cached_freqs[dev_idx][alpha] = torch.stack(
[
torch.cos(freqs),
-torch.sin(freqs),
torch.sin(freqs),
torch.cos(freqs),
],
dim=2,
).view(*freqs.shape, 2, 2)
我们频率计算的实现
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
Hugging Face 频率计算的实现
一旦 RoPE 被修复,我们就能够在单个 A100 GPU 上编译 7B 和 13B 模型,而没有任何图中断。
我们使用了 SDPA,这是 PyTorch 原生实现的高效注意力计算,并启用了跟踪(用于编译)。为了避免由于使用 Python 上下文强制选择单个算法而导致的图中断(推荐的方法),我们必须使用 torch.backends.cuda.enable_*_sdp
函数。
attn = torch.nn.functional.scaled_dot_product_attention(
queries,
keys_e,
values_e,
attn_mask=attn_mask,
dropout_p=self.p_dropout if self.training else 0.0,
is_causal=is_causal_mask,
)
使用 SDPA 的注意力计算
接下来,我们对更大的 70B 模型运行了相同的步骤,发现即使使用半精度,该模型也无法容纳在单个 GPU 中,并且需要张量并行推理。对 70B 模型使用 torch.compile 导致 162 个图中断,这是由于每层两个 all-reduce、前向嵌入一个 all-gather 和反向嵌入一个 all-gather 造成的。因此,我们没有看到推理延迟的显着改善。在撰写本博客时,我们无法使用 PyTorch 的分布式张量实现,因为它不支持编译。我们从头开始重写了张量并行代码,使其仅依赖于可跟踪的集合通信,以便与编译一起工作。在最后一次更改之后,PyTorch 编译器没有引入任何图中断,我们看到了推理延迟的显着加速。具体而言,当使用 8 个 A100 GPU 时,我们测得 Llama 70B 模型的延迟为 29 毫秒/token,比未优化推理提高了 2.4 倍。
服务方面
最后,这里需要注意的一点是,仅仅对模型执行编译不足以在生产环境中服务该模型。为了以高吞吐量实现上述性能,我们需要支持动态批处理、嵌套张量,以及一个预热阶段,我们在其中为分桶序列长度进行预编译。我们正在努力解决这些方面,以在生产环境中实现这种性能。
实验和测量
我们在两个不同的环境(IBM Cloud 和 AWS,均运行 OpenShift)中使用配备 8 个 A100 NVIDIA GPU 和 80G 显卡的节点进行所有测量。首先,我们比较各种技术——eager 模式、使用 SDPA Flash 内核、使用 Compile 以及使用 Compile 和 SDPA。对于 70B 模型,我们在张量并行模式下使用 compile 和 SDPA 运行它。对于此实验,我们使用 512 个 token 作为输入长度,并生成 50 个 token。对于 7B 和 13B 模型,我们使用单个 A100 来测量延迟,而对于 70B 模型,我们使用 8 个 A100。此外,对于 70B 模型,我们使用 PyTorch compile 中的 reduce-overhead 选项,该选项使用 CudaGraphs 来减少 CPU 到 GPU 内核启动开销;在 7B 和 13B 模型中使用 CudaGraphs 没有显示任何好处(因此此处未报告)。我们从图 1 中观察到,compile 和 SDPA 提供了非常低的延迟,70B Llama 2 模型的延迟为 29 毫秒/token。
图 1:不同技术在序列长度为 512 时的中值延迟(在 IBM Cloud A100 服务器上测量)
接下来,我们检查序列长度的影响,我们将其从 1024 增加到 4096,并观察到每个 token 的中值延迟呈次线性增长,这表明当我们将上下文增加到大型文档时,我们不会牺牲响应时间。
图 2:不同序列长度下 compile+SDPA 的中值延迟(在 AWS 上的 A100 上测量)
最后,随着批大小的增加,我们观察到响应延迟呈次线性增长。对于 13B 模型,在批大小为 8 时,我们遇到 OOM 错误。对于 70B 模型,鉴于它在 8 个 GPU 上以张量并行方式运行,我们没有看到任何此类 OOM 问题。
图 3:不同批大小下 compile+SDPA 的中值延迟,序列长度固定为 4096(在 AWS 上的 A100 上测量)
最终想法
我们已经演示了 PyTorch 编译推理路径如何为 70B 模型推理展示超低延迟。下一步是使用上述杠杆启用动态批处理和嵌套张量。
特别感谢 PyTorch 团队的 Edward Yang、Elias Ellison、Driss Guessous、Will Feng、Will Constable、Horace He、Less Wright 和 Andrew Gu,他们的 PR 审查和代码贡献使我们能够使用 PyTorch 原生方法实现延迟。我们感谢更广泛的 PyTorch 团队,他们一直在不懈努力使 PyTorch 变得更好,特别感谢 SDPA 团队在快速内核上启用跟踪和编译,以及编译团队,他们一直在密切指导我们如何解决和修复问题(包括识别和提出 CUDA 图中的 NVIDIA 驱动程序错误)。
推理延迟一直是 LLM 在关键企业工作流程中采用的障碍之一,但另一个主要障碍是对安全性、可信赖性和治理的需求。IBM 的 AI 安全和 LLM 风险指南可在此处找到,Meta 的 LLaMa 负责任用户指南可在此处找到:here and here。
参考文献
- GitHub 资源: https://ibm.biz/fm-stack
- 在 PyTorch/XLA 上实现 LLaMa 65B 超低推理延迟的路径
- 速度,Python:任选其二。CUDA 图如何为深度学习启用快速 Python 代码
- IBM 关于 AI 伦理和信任的资源: https://www.ibm.com/downloads/cas/E5KE5KRZ
- Meta LLaMa 负责任用户指南: https://ai.meta.com/llama/responsible-use-guide/