作者:Milad Mohammadi、Jiewen Tan、Liyang Lu、Siyuan Liu、Yeounoh Chung、Wonjoo Lee、Manfei Bai、Steven Krawczyk、Shauheen Zahirazami、Alex Wertheim、Meghan Cowan、Jack Cao、Joe Spisak

背景 & 最新技术

在自然语言处理 (NLP) 领域,语言模型旨在通过使用一系列过去的输入标记来生成标记(例如单词)。大型语言模型 (LLM) 是该领域最新的深度学习创新,旨在以类似人类的方式生成文本。这些模型通常使用 transformers 来提高其对大量输入标记的注意力。

LLaMA,由 Meta AI 开源,是一个强大的基础 LLM,在超过 1T 个标记上进行训练。LLaMA 可以与许多同类最佳模型相媲美,例如 GPT-3ChinchillaPaLMLLaMA (13B) 的性能优于 GPT-3 (175B),突显了其从每个模型参数中提取更多计算能力的能力。

在本博文中,我们使用 LLaMA 作为示例模型来演示 PyTorch/XLA 在 LLM 推理方面的能力。我们讨论了此处讨论的计算技术和优化如何将 Google Cloud TPU v4 (v4-16) 驱动的 65B 参数 LLaMA 模型的推理延迟提高 6.4 倍。

模型概述

我们演示了 PyTorch/XLA 在 LLaMA(Meta 最新的 LLM)上的性能能力。我们展示了一系列常见 LLaMA 配置的性能优化。请注意,公共领域中缺少 175B 参数模型配置。对于下面提到的 175B 参数模型,我们将 OPT 175B 模型配置应用于 LLaMA 代码库。除非另有说明,在所有配置中,我们都对权重和激活使用 max_seq_len=256dtype=bfloat16

表 1:本文中探讨的模型配置

LLaMA 模型超参数
# 参数 维度 N 头 N 层 最大序列长度
7B 4,096 32 32 256
33B 6,656 52 60 256
65B 8,192 64 80 256
175B 12,288 96 96 256

LLM 的性能挑战

LLM 具有一些使其难以进行编译器优化的属性。(a) LLM 使用自回归解码来生成基于先前标记的下一个标记;这意味着提示张量和教练具有动态形状。(b) LLM 必须处理可变的输入提示长度,而不会因输入张量形状更改而触发重新编译;输入张量必须正确分桶和填充,以避免重新编译。(c) LLM 通常需要比单个 TPU(或 GPU)设备可以支持的内存更多的内存。需要模型分片方案来使模型适合分布式计算架构。例如,具有 65B 参数的 LLaMA 模型可以装配在 v4-16 Cloud TPU 上,这与 8 个 A100 GPU 相当。(d) 在生产环境中运行 LLM 可能很昂贵;提高每总拥有成本 (Perf/TCO) 性能的一种方法是通过量化;量化可能会降低硬件要求。

PyTorch/XLA 中的推理技术堆栈

我们的目标是为 AI 社区提供高性能推理堆栈。PyTorch/XLA 与 TorchDynamoPjRtOpenXLA 和各种模型并行方案集成。TorchDynamo 消除了运行时的跟踪开销,PjRt 实现了高效的主机设备通信;PyTorch/XLA 可跟踪集合通过 TorchDynamo 在 LLaMA 上启用模型和数据并行性。要尝试我们的结果,请使用我们的自定义 torchtorch-xla wheels 来重现我们的 LLaMA 推理解决方案。PyTorch/XLA 2.1 将默认支持本文中讨论的功能。

并行计算

FairScale 分片

LLaMA 使用 FairScale 模型分片 API (fairscale.nn.model_parallel.layers)。我们使用 PyTorch/XLA 通信集合 (CC) 操作(例如 all-reduce)构建了此 API 的等效表示,以在加速器之间通信程序状态(例如激活)。TorchDynamo 目前不完全支持捕获 CC 操作(又名 可跟踪集合)。如果没有此支持,TorchDynamo FX 图将在每次设备通信时被切割,这意味着在每一模型层。图切割会导致性能损失,因为底层 XLA 编译器会失去完整的图优化机会。为了解决这个问题,我们通过将调度程序集合集成到我们现有的 CC API 中来提供 PyTorch/XLA 可跟踪集合。不同之处在于,鉴于 PyTorch/XLA 的延迟执行性质,我们不需要在集合之后插入 c10d.wait() 操作。借助对可跟踪集合的支持,PyTorch/XLA 允许在 TorchDynamo 中生成奇异 FX 图。

PyTorch/XLA 上的自回归解码

LLM 需要自回归解码来将先前的单词作为提示输入,以预测下一个标记。自回归解码会导致无界的动态形状问题,这反过来会导致每次提示都重新编译。我们优化了 LLaMA 自回归解码器,使其使用固定形状运行,该形状在每次标记生成期间就地更新 KV 缓存、输出序列和注意力掩码。通过结合填充、掩码和索引操作,我们避免了过度的图重新编译,从而实现了高效的自回归解码。

KV 缓存优化

LLaMA 使用 KV 缓存实现自回归解码。对于每个生成的标记,KV 缓存存储每个 Transformer 层的注意力键/值激活。因此,在解码新标记时,不再需要重新计算先前标记的键/值。

在 LLaMA 中,KV 缓存张量切片就地更新;这会导致每次生成标记时都发生重新编译事件。为了解决这个问题,我们使用索引张量和 tensor.index_copy() 操作来替换就地切片更新。注意力掩码和输出序列也受益于相同的优化。

输入提示优化

可变长度输入提示在 LLM 应用程序中很常见。此属性会导致输入张量形状动态性,进而导致重新编译事件。当处理提示以填充 KV 缓存时,我们可以 (a) 逐个标记处理输入提示,或 (b) 在一次迭代中处理整个提示。每种方法的优点和缺点是

  1. 预编译 1 个图并逐个标记处理提示
    • 实用:在热身期间编译 1 个图
    • 慢:处理输入提示长度 LO(L) - 长提示的缺点
  2. 预编译所有输入长度范围从 1 到 max_seq_len(例如 2,048)的图
    • 不切实际:在热身时间期间预编译和缓存 max_seq_len 个图
    • 快:1 个图执行来处理完整提示

我们引入了提示长度分桶,这是一种优化,可以在两种替代方案之间取得平衡。我们定义了一组升序桶大小,(b0,b1,b2,…,bB-1),然后根据这些桶值预编译输入大小的程序图,(G0,G1,G2,…,GB-1)B 是桶的数量。对于给定的输入提示,我们将提示长度向上舍入到最接近的桶值 bn,填充序列,并使用 Gn 在一次迭代中处理提示。丢弃对填充标记的计算。对于大于最大桶大小的提示,我们按部分处理它们。

最佳桶大小应由目标应用程序中的提示长度分布确定。在这里,我们采用桶长度:128、256、384、512。任何最多 2,047 个标记的输入提示最多需要 4 个图执行。例如,长度为 256 的 1,500 个输入提示需要 260 个图执行 - 4 个用于处理输入,256 个用于生成输出。

量化

量化减少了表示值所需的位数;它减少了通过集合在多个加速器节点之间通信数据的带宽,并降低了服务特定模型大小的硬件要求。

通常,使用 BF16 权重,175B 参数模型将消耗约 351GB 的内存,因此需要 v4-32 实例来容纳该模型。通过将权重量化为 INT8,我们将模型大小减少了大约 50%,使其可以在较小的 v4-16 实例上运行。由于 LLaMA 对模型激活进行分片,因此量化提供的通信增益可以忽略不计。

在我们的实验中,我们量化了线性层。由于 LLaMA 模型检查点未公开提供,并且我们的目标是评估性能,因此量化模型使用随机权重初始化。最近的文献(例如 AWQ整数还是浮点数?)提供了关于 LLaMA 在各种低位量化方案下的性能属性的见解。

批量大小对量化性能的影响

TPU v4 被编程为在模型批量大小 (BS) > 1 时在矩阵乘法单元 (MXU) 上运行 matmul。对于 BS = 1,matmul 在向量处理器单元 (VPU) 上运行。由于 MXU 比 VPU 更高效,因此 INT8 量化在 BS>1 时会获得性能提升。有关详细信息,请参阅性能分析部分。

操作支持

有时,新模型会引入新的数学运算,这需要 PyTorch/XLA 扩展其支持的操作集以进行编译。对于 LLaMA,我们支持:multinomial

方法论

LLaMA 在 LazyTensorCore 上开箱即用地在 PyTorch/XLA 上工作。我们将此配置用作我们后续分析的基线。所有实验都假设 256 个标记长的输入提示。在没有公开可用的模型检查点的情况下,我们使用随机张量初始化来进行此推理堆栈优化工作。模型检查点预计不会更改此处讨论的延迟结果。

模型大小调整

假设 N 是参数数量,dimensions 是隐藏大小,n_layers 是层数,n_heads 是注意力头的数量,可以使用以下公式来近似模型大小。有关详细信息,请参阅模型概述部分。

N = (dimensions)^2 * n_layers * 12

n_heads 不影响 N,但以下公式适用于开源模型配置。

dim = 128 * n_heads

缓存大小调整

模型参数和 Attention 块中的缓存层都会导致内存消耗。由于默认的 LLaMA 模型使用 BF16 权重,因此本节中的内存消耗计算基于 BF16 权重。

缓存层的大小计算公式为 cache_size = max_batch_size * max_seq_len * dimensionsmax_batch_size = 1max_seq_len = 256 在以下计算中用作示例配置。每个 Attention 块中有 2 个缓存层。因此,总 LLaMA 缓存大小(以字节为单位)为 total_cache_size = n_layers * 2 * cache_size * (2 bytes)

TPU v4 硬件大小调整

每个 TPU v4 芯片都有 32GB 的可用高带宽内存 (HBM)。表 2 详细介绍了内存消耗以及容纳 LLaMA 模型所需的 TPU 芯片数量。

表 2:LLaMA TPU v4 HBM 要求(即 TPU v4 芯片要求)

# 参数 参数 (MB) 缓存 (MB) 总计 (GB) 最少 TPU v4 芯片数
7B 14,000 134 14.128 1
33B 66,000 408 66.41 3
65B 130,000 671 130.67 5
175B 350,000 1,208 351.21 11

指标

以下是衡量推理速度的有用指标。假设 T 是总时间,B 是批量大小,L 是解码序列长度。

延迟定义

延迟是指在目标长度 L 处获得解码结果所需的时间,与批量大小 B 无关。延迟表示用户应等待多长时间才能从生成模型获得响应。

Latency = T (s)

每标记延迟

自回归解码的一个步骤为批量中的每个样本生成一个标记。每标记延迟是该步骤的平均时间。

Per-token latency = T / L (s/token)

吞吐量

吞吐量衡量每单位时间生成的标记数量。虽然它不是评估在线服务的有用指标,但它对于衡量批量处理的速度很有用。

Throughput = B * L / T (tokens/s)

为了最大限度地减少混淆和误解,最好避免使用 T / (B * L) 等指标,因为它混合了延迟和吞吐量。

结果

图 1 显示了 LLaMA 7B 到 175B 模型的延迟/标记结果。在每种情况下,模型都在一系列 TPU v4 配置上运行。例如,LLaMA 7B 在 v4-8 和 v4-16 上分别显示 4.7 毫秒/标记和 3.8 毫秒/标记。有关更多比较,请访问 HuggingFace LLM 性能排行榜

在没有本文中讨论的功能的情况下,在 v4-32 上运行的 LLaMA 65B 提供 120 毫秒/标记,而不是此处获得的 14.5 毫秒/标记,从而实现 8.3 倍 的加速。如前所述,鼓励开发者尝试我们的自定义 torchtorch-xla wheels,它们解锁了此处共享的 LLaMA 推理结果的重现。

Figure 1: LLaMA Inference Performance on TPU v4 hardware

图 1:TPU v4 硬件上的 LLaMA 推理性能

PyTorch/XLA:GPU 性能优于 PyTorch:GPU eager,并且与 PyTorch Inductor 相似。PyTorch/XLA:TPU 性能优于 PyTorch/XLA:GPU。在不久的将来,XLA:GPU 将提供优化,使其与 XLA:TPU 达到同等水平。单个 A100 配置仅适用于 LLaMA 7B,而 8-A100 不适用于 LLaMA 175B。

Figure 2: LLaMA Inference Performance on GPU A100 hardware

图 2:GPU A100 硬件上的 LLaMA 推理性能

随着批量大小的增加,我们观察到每标记延迟的次线性增加,突显了硬件利用率和延迟之间的权衡。

Figure 3: LLaMA Inference Performance across different batch sizes

图 3:不同批量大小下的 LLaMA 推理性能

我们的研究表明,最大序列输入长度 (max_seq_len) 对推理延迟的影响相对较小。我们将此归因于标记生成的顺序和迭代性质。性能上的微小差异可能是由于 KV 缓存访问延迟随存储大小的增加而变化。

Figure 4: LLaMA Inference Performance across different prompt lengths

图 4:不同提示长度下的 LLaMA 推理性能

LLM 通常是内存受限的应用程序;因此,通过量化模型参数,我们可以实现在每个单位时间内在 MXU 上加载和执行更大的张量(即 HBM ⇒ CMEM 和 CMEM ⇒ MXU 数据移动)。图 5 显示 INT8 仅权重量化提供 1.6 倍至 1.9 倍的加速,从而允许在给定硬件上运行更大的模型。

当 BS=1 时,INT8 张量被调度到小于 MXU 的 VPU(请参阅 TPU v4 论文);否则,使用 MXU。因此,当 BS=1 时,量化内存带宽增益被缺乏 MXU 利用率所抵消。但是,当 BS>1 时,内存增益在量化模型上提供卓越的延迟。例如,在 175B 参数 LLaMA 的情况下,具有量化的 v4-16 和没有量化的 v4-32 提供相似的性能。请注意,我们不提供 FP8 比较,因为 PyTorch 尚未提供此数据类型。

Figure 5: LLaMA Inference Performance vs. weight-only quantization. The missing blue bars suggest the model size doesn’t fit in the specified TPU hardware.

图 5:LLaMA 推理性能与仅权重量化。缺失的蓝色条表示模型大小不适合指定的 TPU 硬件。

图 6 展示了随着输入提示长度从 10 个标记增长到 1,500 个标记,PyTorch/XLA 的稳定性能优势。这种强大的扩展能力表明 PyTorch/XLA 重新编译事件最少,从而支持广泛的实际应用。在此实验中,最大长度为 2,048,最大生成长度为 256。

Figure 6: LLaMA Inference Performance vs. Input Prompt Length

图 6:LLaMA 推理性能与输入提示长度

最终想法

我们对 PyTorch/XLA 的未来感到非常兴奋,并邀请社区加入我们。PyTorch/XLA 完全在开源中开发。因此,请在 GitHub 上提交问题、提交拉取请求和发送 RFC,以便我们可以公开协作。您还可以亲自试用 PyTorch/XLA 在各种 XLA 设备(包括 TPU 和 GPU)上的表现。

致谢,
Google 的 PyTorch/XLA 团队
#PoweredByPyTorch