背景与最先进技术
在自然语言处理 (NLP) 领域,语言模型旨在利用过去的输入词元序列生成一个词元(例如,单词)。大型语言模型 (LLM) 是该领域最新的深度学习创新,旨在以类人的方式生成文本。这些模型通常使用Transformer来改进它们在大量输入词元上的注意力。
LLaMA 由 Meta AI 开源,是一个强大的基础 LLM,经过超过 1T 词元的训练。LLaMA 与许多一流模型(如 GPT-3、Chinchilla、PaLM)具有竞争力。LLaMA (13B) 优于 GPT-3 (175B),这突出显示了它从每个模型参数中提取更多计算量的能力。
在这篇博文中,我们使用 LLaMA 作为示例模型来演示 PyTorch/XLA 在 LLM 推理方面的能力。我们讨论了本文中讨论的计算技术和优化如何将由 Google Cloud TPU v4 (v4-16) 提供支持的 65B 参数 LLaMA 模型的推理延迟缩短 6.4 倍。
模型概述
我们展示了 PyTorch/XLA 在 Meta 最新 LLM LLaMA 上的性能。我们展示了一系列常见 LLaMA 配置上的性能优化。请注意,175B 参数模型配置不在公共领域。对于下面提到的 175B 参数模型,我们将 OPT 175B 模型配置 应用于 LLaMA 代码库。除非另有说明,在所有配置中,我们都使用 max_seq_len=256 和 dtype=bfloat16 作为权重和激活。
表 1:本文探讨的模型配置
| LLaMA | 模型超参数 | |||
| # 参数 | 维度 | 头数 | 层数 | 最大序列长度 |
| 7B | 4,096 | 32 | 32 | 256 |
| 33B | 6,656 | 52 | 60 | 256 |
| 65B | 8,192 | 64 | 80 | 256 |
| 175B | 12,288 | 96 | 96 | 256 |
LLM 的性能挑战
LLM 具有一些特性,这些特性使其对编译器优化构成挑战。(a) LLM 使用自回归解码根据前一个词元生成下一个词元;这意味着提示张量和缓存具有动态形状。(b) LLM 必须处理可变输入提示长度,而不会因输入张量形状变化而触发重新编译;输入张量必须经过适当的分桶和填充以避免重新编译。(c) LLM 通常需要比单个 TPU(或 GPU)设备所能支持的内存更多。需要模型分片方案才能将模型分布到分布式计算架构中。例如,一个具有 65B 参数的 LLaMA 模型可以安装在 v4-16 Cloud TPU 上,这相当于 8 个 A100 GPU。(d) 在生产环境中运行 LLM 可能很昂贵;提高每总拥有成本性能 (Perf/TCO) 的一种方法是通过量化;量化可以潜在地降低硬件要求。
PyTorch/XLA 中的推理技术栈
我们的目标是为 AI 社区提供高性能推理栈。PyTorch/XLA 与 TorchDynamo、PjRt、OpenXLA 以及各种模型并行方案集成。TorchDynamo 消除了运行时的追踪开销,PjRt 实现了高效的主机-设备通信;PyTorch/XLA 可追踪的集合通信通过 TorchDynamo 在 LLaMA 上实现了模型和数据并行。要尝试我们的结果,请使用我们的自定义 torch、torch-xla 轮子来重现我们的 LLaMA 推理解决方案。PyTorch/XLA 2.1 将默认支持本文讨论的功能。
并行计算
FairScale 分片
LLaMA 使用 FairScale 模型分片 API(fairscale.nn.model_parallel.layers)。我们使用 PyTorch/XLA 通信集合 (CC) 操作(例如 all-reduce)构建了此 API 的等效表示,以在加速器之间通信程序状态(例如激活)。TorchDynamo 目前不完全支持捕获 CC 操作(又称 可追踪集合通信)。如果没有此支持,TorchDynamo FX 图将在每次设备通信(即每个模型层)处被切断。图切断会导致性能损失,因为底层 XLA 编译器会失去完全图优化机会。为了解决这个问题,我们通过将调度器集合通信集成到我们现有的 CC API 中来提供 PyTorch/XLA 可追踪集合通信。不同之处在于,鉴于 PyTorch/XLA 的惰性执行特性,我们不需要在集合通信后插入 c10d.wait() 操作。通过对可追踪集合通信的支持,PyTorch/XLA 允许在 TorchDynamo 中生成单一 FX 图。
PyTorch/XLA 上的自回归解码
LLM 需要自回归解码,将前一个单词作为提示来预测下一个词元。自回归解码会导致无边界的动态形状问题,进而导致每个提示的重新编译。我们优化了 LLaMA 自回归解码器,使其在每个词元生成过程中以固定形状操作,从而原地更新 KV 缓存、输出序列和注意力掩码。通过填充、掩码和索引操作的组合,我们避免了过多的图重新编译,从而实现了高效的自回归解码。
KV 缓存优化
LLaMA 使用 KV 缓存实现自回归解码。对于每个生成的词元,KV 缓存存储每个 Transformer 层的注意力键/值激活。因此,在解码新词元时,不再需要重新计算先前词元的键/值。
在 LLaMA 中,KV 缓存张量切片是原地更新的;这导致每次生成词元时都会发生重新编译事件。为了解决这个问题,我们使用索引张量和 tensor.index_copy() 操作来替换原地切片更新。注意力掩码和输出序列也受益于相同的优化。
输入提示优化
可变长度输入提示在 LLM 应用程序中很常见。此属性导致输入张量形状动态性,进而导致重新编译事件。在处理提示以填充 KV 缓存时,我们(a)逐个词元处理输入提示,或者(b)一次迭代处理整个提示。每种方法的优缺点如下:
- 预编译 1 个图并逐个词元处理提示
- 实用:在热身期间编译 1 个图
- 慢:处理输入提示长度 L 需要 O(L) 时间——对于长提示来说是一个缺点
- 预编译所有输入长度从 1 到 max_seq_len(例如 2,048)的图
- 不实用:在热身期间预编译和缓存 max_seq_len 个图
- 快:1 次图执行即可处理完整的提示
我们引入了提示长度分桶,这是一种在两种替代方案之间取得平衡的优化。我们定义了一组升序桶大小,(b0,b1,b2,…,bB-1),然后根据这些桶值预编译具有输入大小的程序图,(G0,G1,G2,…,GB-1);B 是桶的数量。对于给定的输入提示,我们将提示长度向上舍入到最接近的桶值 bn,填充序列,并使用 Gn 在一次迭代中处理提示。填充词元上的计算被丢弃。对于大于最大桶大小的提示,我们分段处理它们。
最佳桶大小应由目标应用程序中的提示长度分布决定。在这里,我们采用的桶长度为:128、256、384、512。任何多达 2,047 个词元的输入提示最多需要 4 次图执行。例如,一个具有 256 生成长度的 1,500 输入提示需要 260 次图执行——4 次用于处理输入,256 次用于生成输出。
量化
量化减少了表示值所需的位数;它减少了在多个加速器节点之间(通过集合通信)通信数据的带宽,并降低了服务特定模型大小的硬件要求。
通常,对于 BF16 权重,一个 175B 参数模型将消耗大约 351GB 内存,因此需要 v4-32 实例来容纳该模型。通过将权重量化为 INT8,我们将模型大小减少了大约 50%,使其可以在较小的 v4-16 实例上运行。由于 LLaMA 分片模型激活,量化提供的通信增益微不足道。
在我们的实验中,我们对线性层进行了量化。由于 LLaMA 模型检查点未公开,且我们的目标是评估性能,因此量化模型使用随机权重进行初始化。最近的文献,例如 AWQ 和 整数还是浮点数?,提供了关于 LLaMA 在各种低位量化方案下的性能特性的见解。
批处理大小对量化性能的影响
TPU v4 被编程为在模型批处理大小 (BS) > 1 时在矩阵乘法单元 (MXU) 上运行 matmul。对于 BS = 1,matmul 在向量处理器单元 (VPU) 上运行。由于 MXU 比 VPU 更高效,因此 INT8 量化在 BS > 1 时获得性能提升。有关详细信息,请参阅 性能分析 部分。
操作支持
偶尔,新模型会引入新的数学运算,这需要 PyTorch/XLA 扩展其支持的编译操作集。对于 LLaMA,我们支持:多项式分布。
方法论
LLaMA 在 LazyTensorCore 上的 PyTorch/XLA 中开箱即用。我们使用此配置作为后续分析的基准。所有实验均假设 256 长输入提示。在没有公开可用模型检查点的情况下,我们使用随机张量初始化来完成此推理栈优化工作。模型检查点预计不会改变此处讨论的延迟结果。
模型大小
假设 N 是参数数量,dimensions 是隐藏大小,n_layers 是层数,n_heads 是注意力头数,则可以使用以下方程来估算模型大小。有关详细信息,请参见模型概述部分。
N = (dimensions)^2 * n_layers * 12
n_heads 不影响 N,但以下方程适用于开源模型配置。
dim = 128 * n_heads
缓存大小
注意力块中的模型参数和缓存层都占用内存。由于默认的 LLaMA 模型使用 BF16 权重,本节中的内存消耗计算基于 BF16 权重。
缓存层的大小由 cache_size = max_batch_size * max_seq_len * dimensions 计算。在以下计算中,max_batch_size = 1 和 max_seq_len = 256 用作示例配置。每个注意力块中有 2 个缓存层。因此,LLaMA 的总缓存大小(以字节为单位)为 total_cache_size = n_layers * 2 * cache_size * (2 bytes)。
TPU v4 硬件尺寸
每个 TPU v4 芯片都有 32GB 可用高带宽内存 (HBM)。表 2 详细说明了内存消耗和保存 LLaMA 模型所需的 TPU 芯片数量。
表 2:LLaMA TPU v4 HBM 要求(即 TPU v4 芯片要求)
| # 参数 | 参数 (MB) | 缓存 (MB) | 总计 (GB) | 所需 TPU v4 芯片最少数量 |
| 7B | 14,000 | 134 | 14.128 | 1 |
| 33B | 66,000 | 408 | 66.41 | 3 |
| 65B | 130,000 | 671 | 130.67 | 5 |
| 175B | 350,000 | 1,208 | 351.21 | 11 |
指标
以下是衡量推理速度的有用指标。假设 T 是总时间,B 是批处理大小,L 是解码序列长度。
延迟定义
延迟是指在目标长度 L 下获取解码结果所需的时间,与批处理大小 B 无关。延迟表示用户需要等待多长时间才能从生成模型获得响应。
Latency = T (s)
每词元延迟
自回归解码的一个步骤会为批处理中的每个样本生成一个词元。每词元延迟是该一步的平均时间。
Per-token latency = T / L (s/token)
吞吐量
吞吐量衡量单位时间内生成的词元数量。虽然它对于评估在线服务不是一个有用的指标,但它对于衡量批处理速度很有用。
Throughput = B * L / T (tokens/s)
为了尽量减少混淆和误解,最好避免使用 T / (B * L) 等混合了延迟和吞吐量的指标。
结果
图 1 显示了 LLaMA 7B 到 175B 模型的每词元延迟结果。在每种情况下,模型都在一系列 TPU v4 配置上运行。例如,LLaMA 7B 在 v4-8 和 v4-16 上分别显示 4.7 毫秒/词元和 3.8 毫秒/词元。更多比较,请访问 HuggingFace LLM 性能排行榜。
如果没有本文讨论的功能,在 v4-32 上运行的 LLaMA 65B 将提供 120ms/词元的延迟,而不是此处获得的 14.5ms/词元,从而实现 8.3 倍 的加速。如前所述,鼓励开发人员尝试我们的自定义 torch、torch-xla 轮子,这些轮子可以重现此处共享的 LLaMA 推理 结果。
图 1:TPU v4 硬件上的 LLaMA 推理性能
PyTorch/XLA:GPU 性能优于 PyTorch:GPU eager 模式,与 PyTorch Inductor 类似。PyTorch/XLA:TPU 性能优于 PyTorch/XLA:GPU。在不久的将来,XLA:GPU 将提供与 XLA:TPU 相当的优化。单个 A100 配置仅适用于 LLaMA 7B,8-A100 不适用于 LLaMA 175B。
图 2:GPU A100 硬件上的 LLaMA 推理性能
随着批处理大小的增加,我们观察到每词元延迟呈次线性增长,突出了硬件利用率和延迟之间的权衡。
图 3:不同批处理大小下的 LLaMA 推理性能
我们的研究表明,最大序列输入长度 (max_seq_len) 对推理延迟的影响相对较小。我们将其归因于词元生成的顺序和迭代性质。性能上的微小差异可能是由于存储大小增加导致 KV 缓存访问延迟变化。
图 4:不同提示长度下的 LLaMA 推理性能
LLM 通常是内存受限的应用程序;因此,通过量化模型参数,我们能够以单位时间在 MXU 上加载和执行更大的张量(即 HBM => CMEM 和 CMEM => MXU 数据移动)。图 5 显示,INT8 仅权重量化提供了 1.6 倍至 1.9 倍的加速,从而可以在给定硬件上运行更大的模型。
当 BS=1 时,INT8 张量被调度到 VPU,VPU 小于 MXU(参见 TPU v4 论文);否则,使用 MXU。因此,当 BS=1 时,量化内存带宽增益被 MXU 利用率不足所抵消。然而,当 BS>1 时,内存增益在量化模型上提供了卓越的延迟。例如,在 175B 参数 LLaMA 的情况下,带量化的 v4-16 和不带量化的 v4-32 提供相似的性能。请注意,我们不提供 FP8 比较,因为 PyTorch 尚未提供此数据类型。
图 5:LLaMA 推理性能与仅权重量化。缺失的蓝色条表示模型大小不适合指定的 TPU 硬件。
图 6 展示了 PyTorch/XLA 随着输入提示长度从 10 个词元增长到 1,500 个词元而保持的稳定性能优势。这种强大的扩展能力表明 PyTorch/XLA 重新编译事件最少,从而支持广泛的实际应用。在本实验中,最大长度为 2,048,最大生成长度为 256。
图 6:LLaMA 推理性能与输入提示长度
最终想法
我们对 PyTorch/XLA 的未来感到兴奋,并邀请社区加入我们。PyTorch/XLA 完全以开源形式开发。因此,请在 GitHub 上提交问题、拉取请求和发送 RFC,以便我们公开协作。您还可以 亲自尝试 在各种 XLA 设备(包括 TPU 和 GPU)上使用 PyTorch/XLA。
祝好,
Google 的 PyTorch/XLA 团队
#PoweredByPyTorch