作者:Milad Mohammadi, Jiewen Tan, Liyang Lu, Siyuan Liu, Yeounoh Chung, Wonjoo Lee, Manfei Bai, Steven Krawczyk, Shauheen Zahirazami, Alex Wertheim, Meghan Cowan, Jack Cao, Joe Spisak

背景与现有技术

在自然语言处理 (NLP) 领域,语言模型旨在利用过去的一系列输入 token 来生成 token(例如单词)。大型语言模型 (LLM) 是该领域的最新深度学习创新成果,用于生成类似人类的文本。这些模型通常使用Transformer 来提高它们对大序列输入 token 的注意力。

Meta AI 开源的LLaMA 是一个强大的基础 LLM,在超过 1 万亿个 token 上进行训练。LLaMA 与许多一流模型具有竞争力,例如GPT-3ChinchillaPaLMLLaMA (13B) 在性能上优于 GPT-3 (175B),突显了其从每个模型参数中提取更多计算能力的能力。

在这篇博文中,我们以 LLaMA 为例,演示 PyTorch/XLA 在 LLM 推理方面的能力。我们讨论了此处讨论的计算技术和优化如何在使用 Google Cloud TPU v4 (v4-16) 的 65B 参数 LLaMA 模型上将推理速度提升 6.4 倍,从而降低推理延迟。

模型概述

我们展示了 PyTorch/XLA 在 Meta 的最新 LLM LLaMA 上的性能能力。我们在一系列常见的 LLaMA 配置上展示了性能优化。请注意,公开领域中没有 175B 参数模型的配置。对于下面提到的 175B 参数模型,我们将OPT 175B 模型配置应用于 LLaMA 代码库。除非另有说明,在所有配置中,我们对权重和激活使用max_seq_len=256dtype=bfloat16

表 1:本文探索的模型配置

LLaMA 模型超参数
参数数量 维度 头数量 层数量 最大序列长度
7B 4,096 32 32 256
33B 6,656 52 60 256
65B 8,192 64 80 256
175B 12,288 96 96 256

LLM 的性能挑战

LLM 具有一些属性,使得编译器优化面临挑战。(a) LLM 使用自回归解码根据前一个 token 生成下一个 token;这意味着 prompt 张量和缓存具有动态形状。(b) LLM 必须处理可变长度的输入 prompt,同时避免因输入张量形状变化而触发重新编译;必须对输入张量进行适当的分桶和填充以避免重新编译。(c) LLM 通常需要比单个 TPU(或 GPU)设备更多的内存。需要模型分片方案才能将模型适应分布式计算架构。例如,具有 65B 参数的 LLaMA 模型可以适应 v4-16 Cloud TPU,这与 8 个 A100 GPU 可比。(d) 在生产环境中运行 LLM 可能成本高昂;提高总拥有成本性能比 (Perf/TCO) 的一种方法是通过量化;量化可以潜在地降低硬件要求。

PyTorch/XLA 中的推理技术栈

我们的目标是为 AI 社区提供一个高性能的推理技术栈。PyTorch/XLA 集成了 TorchDynamoPjRtOpenXLA 以及各种模型并行方案。TorchDynamo 消除了运行时的跟踪开销,PjRt 实现了高效的宿主-设备通信;PyTorch/XLA 可追踪集合通信 (traceable collectives) 通过 TorchDynamo 在 LLaMA 上实现了模型并行和数据并行。要尝试我们的结果,请使用我们的自定义 torchtorch-xla wheel 包来重现此处分享的 LLaMA 推理解决方案。PyTorch/XLA 2.1 将默认支持本文讨论的功能。

并行计算

FairScale 分片

LLaMA 使用 FairScale 模型分片 API (fairscale.nn.model_parallel.layers)。我们使用 PyTorch/XLA 集合通信 (CC) 操作(例如 all-reduce)构建了此 API 的等效表示,以在加速器之间通信程序状态(例如激活)。TorchDynamo 目前不完全支持捕获 CC 操作(也称为可追踪集合通信)。没有此支持,TorchDynamo FX 图将在每次设备通信(即每个模型层)处被切断。图切断会导致性能损失,因为底层的 XLA 编译器失去了全面的图优化机会。为了解决这个问题,我们通过将调度器集合通信集成到我们现有的 CC API 中来提供 PyTorch/XLA 可追踪集合通信。不同之处在于,鉴于 PyTorch/XLA 的惰性执行特性,我们不需要在集合通信后插入 c10d.wait() 操作。通过支持可追踪集合通信,PyTorch/XLA 允许在 TorchDynamo 中生成单一 FX 图。

PyTorch/XLA 上的自回归解码

LLM 需要自回归解码才能将前一个单词作为 prompt 输入以预测下一个 token。自回归解码会导致无界的动态形状问题,进而导致每个 prompt 都需要重新编译。我们优化了 LLaMA 自回归解码器,使其以固定形状运行,并在每次生成 token 时原地更新 KV 缓存、输出序列和注意力掩码。通过结合使用填充 (padding)、掩码 (masking) 和索引操作 (index ops),我们避免了过度的图重新编译,从而实现了高效的自回归解码。

KV 缓存优化

LLaMA 使用 KV 缓存实现了自回归解码。对于每个生成的 token,KV 缓存存储每个 Transformer 层的注意力 key/value 激活。因此,在解码新 token 时,不再需要重新计算先前 token 的 key/value。

在 LLaMA 中,KV 缓存张量切片是原地更新的;这导致每次生成 token 时都会发生重新编译事件。为了解决这个问题,我们使用索引张量和 tensor.index_copy() 操作来替换原地切片更新。注意力掩码和输出序列也受益于相同的优化。

输入 Prompt 优化

可变长度输入 prompt 在 LLM 应用中很常见。此属性会导致输入张量形状动态性,进而引发重新编译事件。在处理 prompt 以填充 KV 缓存时,我们可以选择 (a) 逐 token 处理输入 prompt,或 (b) 在一次迭代中处理整个 prompt。每种方法的优缺点如下:

  1. 预编译 1 个图并逐 token 处理 prompt
    • 实用:在预热期间编译 1 个图
    • 慢:处理长度为 L 的输入 prompt 需要 O(L) 时间 - 对于长 prompt 是一个缺点
  2. 预编译所有输入长度从 1 到 max_seq_len(例如 2,048)的图
    • 不实用:在预热期间预编译和缓存 max_seq_len 个图
    • 快:执行 1 个图处理完整 prompt

我们引入了 prompt 长度分桶 (prompt length bucketization),这是一种在上述两种方案之间取得平衡的优化方法。我们定义了一组递增的桶大小,(b0,b1,b2,…,bB-1),然后根据这些桶值预编译输入大小对应的程序图,(G0,G1,G2,…,GB-1)B 是桶的数量。对于给定的输入 prompt,我们将 prompt 长度向上取整到最接近的桶值 bn,填充序列,然后使用 Gn 在一次迭代中处理 prompt。对填充 token 的计算将被丢弃。对于大于最大桶大小的 prompt,我们分段处理。

最佳的桶大小应由目标应用中的 prompt 长度分布决定。在此,我们采用的桶长度为:128、256、384、512。任何最多 2,047 个 token 的输入 prompt 最多需要执行 4 次图。例如,一个包含 1,500 个输入 token 且生成长度为 256 的 prompt 需要执行 260 次图 - 4 次用于处理输入,256 次用于生成输出。

量化

量化减少了表示一个值所需的比特数;它减少了跨多个加速器节点(通过集合通信)通信数据的带宽,并降低了服务特定模型大小所需的硬件要求。

通常,对于 BF16 权重的 175B 参数模型,大约会消耗 351GB 内存,因此需要一个 v4-32 实例来容纳模型。通过将权重量化到 INT8,我们将模型大小减小了大约 50%,使其能够在更小的 v4-16 实例上运行。由于 LLaMA 分片模型激活,量化在通信方面带来的收益可以忽略不计。

在我们的实验中,我们对线性层进行了量化。由于 LLaMA 模型 checkpoint 未公开可用,并且我们的目标是评估性能,因此量化模型使用随机权重进行初始化。最近的文献,例如 AWQInteger or Floating Point?,提供了 LLaMA 在各种低比特量化方案下的性能特性见解。

批量大小对量化性能的影响

当模型批量大小 (BS) > 1 时,TPU v4 被编程为在矩阵乘法单元 (MXU) 上运行 matmul。当 BS = 1 时,matmul 在向量处理器单元 (VPU) 上运行。由于 MXU 比 VPU 更高效,因此在 BS>1 时,INT8 量化能带来性能提升。详情请参阅性能分析部分。

操作支持

有时,新模型会引入新的数学运算,需要 PyTorch/XLA 扩展其支持的操作集以便进行编译。对于 LLaMA,我们支持了:multinomial

方法论

LLaMA 在 LazyTensorCore 上开箱即用地支持 PyTorch/XLA。我们将此配置用作后续分析的基准。所有实验均假设输入 prompt 长度为 256。由于没有公开可用的模型 checkpoint,我们在进行此推理技术栈优化工作时使用了随机张量初始化。模型 checkpoint 预计不会改变此处讨论的延迟结果。

模型大小计算

假设 N 是参数数量,dimensions 是隐藏层大小,n_layers 是层数量,n_heads 是注意力头数量,可以使用下面的公式来近似计算模型大小。详情请参阅模型概述部分。

N = (dimensions)^2 * n_layers * 12

n_heads 不影响 N,但对于开源模型配置,以下等式成立。

dim = 128 * n_heads

缓存大小计算

模型参数和注意力块中的缓存层都会消耗内存。由于默认的 LLaMA 模型使用 BF16 权重,本节中的内存消耗计算基于 BF16 权重。

缓存层的大小计算公式为 cache_size = max_batch_size * max_seq_len * dimensionsmax_batch_size = 1max_seq_len = 256 在以下计算中用作示例配置。每个注意力块中有 2 个缓存层。因此,LLaMA 的总缓存大小(以字节为单位)为 total_cache_size = n_layers * 2 * cache_size * (2 bytes)

TPU v4 硬件大小计算

每个 TPU v4 芯片拥有 32GB 可用高带宽内存 (HBM)。表 2 详细说明了内存消耗以及容纳 LLaMA 模型所需的 TPU 芯片数量。

表 2:LLaMA TPU v4 HBM 需求(即 TPU v4 芯片需求)

参数数量 参数 (MB) 缓存 (MB) 总计 (GB) 最少 TPU v4 芯片数量
7B 14,000 134 14.128 1
33B 66,000 408 66.41 3
65B 130,000 671 130.67 5
175B 350,000 1,208 351.21 11

指标

以下是衡量推理速度的有用指标。假设 T 是总时间,B 是批量大小,L 是解码序列长度。

延迟定义

延迟是指获取目标长度 L 的解码结果所需的时间,与批量大小 B 无关。延迟表示用户需要等待多长时间才能从生成模型获得响应。

Latency = T (s)

每 token 延迟

自回归解码的一个步骤会为批量中的每个样本生成一个 token。每 token 延迟是该步骤的平均时间。

Per-token latency = T / L (s/token)

吞吐量

吞吐量衡量单位时间内生成多少个 token。虽然它不是评估在线服务的有用指标,但它对于衡量批量处理的速度很有用。

Throughput = B * L / T (tokens/s)

为了最大程度地减少混淆和误解,最好避免使用诸如 T / (B * L) 这样的指标,它混合了延迟和吞吐量。

结果

图 1 显示了 LLaMA 7B 到 175B 模型的每 token 延迟结果。在每种情况下,模型都在一系列 TPU v4 配置上运行。例如,LLaMA 7B 在 v4-8 和 v4-16 上的每 token 延迟分别为 4.7ms 和 3.8ms。更多比较,请访问 HuggingFace LLM 性能排行榜

在没有本文讨论的功能的情况下,运行在 v4-32 上的 LLaMA 65B 的每 token 延迟为 120ms,而不是此处获得的 14.5ms,从而实现了 8.3 倍的加速。如前所述,我们鼓励开发者尝试我们的自定义 torchtorch-xla wheel 包,这些包解锁了此处分享的 LLaMA 推理结果的重现。

Figure 1: LLaMA Inference Performance on TPU v4 hardware

图 1:LLaMA 在 TPU v4 硬件上的推理性能

PyTorch/XLA:GPU 的性能优于 PyTorch:GPU eager 模式,并与 PyTorch Inductor 类似。PyTorch/XLA:TPU 的性能优于 PyTorch/XLA:GPU。在不久的将来,XLA:GPU 将提供优化,使其与 XLA:TPU 的性能相当。单个 A100 配置仅适用于 LLaMA 7B,而 8-A100 配置不适用于 LLaMA 175B。

Figure 2: LLaMA Inference Performance on GPU A100 hardware

图 2:LLaMA 在 GPU A100 硬件上的推理性能

随着批量大小的增加,我们观察到每 token 延迟的次线性增长 (sublinear increase),这突显了硬件利用率和延迟之间的权衡。

Figure 3: LLaMA Inference Performance across different batch sizes

图 3:不同批量大小下 LLaMA 的推理性能

我们的研究表明,最大序列输入长度 (max_seq_len) 对推理延迟的影响相对较小。我们将其归因于 token 生成的顺序性和迭代性。性能上的微小差异可能是由于存储大小增加时 KV 缓存访问延迟发生变化所致。

Figure 4: LLaMA Inference Performance across different prompt lengths

图 4:不同 prompt 长度下 LLaMA 的推理性能

LLM 通常是内存瓶颈应用;因此,通过量化模型参数,我们可以在单位时间内在 MXU 上加载和执行更大的张量(即 HBM ⇒ CMEM 和 CMEM ⇒ MXU 数据移动)。图 5 显示,INT8 仅权重 (weight-only) 量化提供了 1.6 到 1.9 倍的加速,从而允许在给定硬件上运行更大的模型。

当 BS=1 时,INT8 张量被分派到 VPU,VPU 小于 MXU(参见TPU v4 论文);否则使用 MXU。因此,当 BS=1 时,量化带来的内存带宽增益被 MXU 利用率不足所抵消。然而,当 BS>1 时,内存增益在量化模型上带来了更好的延迟。例如,在 175B 参数 LLaMA 的情况下,带有量化的 v4-16 和没有量化的 v4-32 提供了相似的性能。请注意,我们没有提供 FP8 比较,因为 PyTorch 尚未提供此数据类型。

Figure 5: LLaMA Inference Performance vs. weight-only quantization. The missing blue bars suggest the model size doesn’t fit in the specified TPU hardware.

图 5:LLaMA 推理性能对比仅权重 (weight-only) 量化。缺失的蓝色条表示模型大小无法在指定的 TPU 硬件中容纳。

图 6 展示了 PyTorch/XLA 在输入 prompt 长度从 10 个 token 增长到 1,500 个 token 时所表现出的稳定性能优势。这种强大的扩展能力表明 PyTorch/XLA 的重新编译事件很少,从而支持广泛的实际应用。在此实验中,最大长度为 2,048,最大生成长度为 256。

Figure 6: LLaMA Inference Performance vs. Input Prompt Length

图 6:LLaMA 推理性能对比输入 Prompt 长度

总结思考

我们对 PyTorch/XLA 的未来感到兴奋,并邀请社区加入我们。PyTorch/XLA 完全在开源中开发。因此,请在 GitHub 上提交 issue、发送 pull request 和 RFC,以便我们公开协作。您也可以在包括 TPU 和 GPU 在内的各种 XLA 设备上亲自试用 PyTorch/XLA。

祝好,
Google PyTorch/XLA 团队
#由PyTorch驱动