作者:Jiewen Tan, Jon Bolin, Yeounoh Chung, Liyang Lu, Siyuan Liu, Wonjoo Lee, Manfei Bai, Meghan Cowan, Jack Cao, Milad Mohammadi, Shauheen Zahirazami, Alex Spiridonov

在 AI 创新以前所未有的速度加速发展的背景下,Meta 开源的大型语言模型 (LLM) 系列 Llama 脱颖而出,成为一项显著的突破。Llama 标志着 LLM 向前迈出了重要一步,展示了预训练架构在广泛应用中的强大能力。Llama 2 进一步拓展了规模和能力的界限,在语言理解、生成及其他领域激发了新的进展。

在 Llama 发布后不久,我们发表了一篇博客文章,展示了如何在 Cloud TPU v4 上使用 PyTorch/XLA 为 Llama 实现超低推理延迟。基于这些成果,今天我们很荣幸分享 Llama 2 在 Cloud TPU v4 以及我们最新的 AI 超级计算机 Cloud TPU v5e 上使用 PyTorch/XLA 的训练和推理性能。

在这篇博客文章中,我们以 Llama 2 为例,展示了 PyTorch/XLA 在 Cloud TPU 上进行 LLM 训练和推理的强大能力。我们讨论了用于提高推理吞吐量和训练模型 FLOPs 利用率 (MFU) 的计算技术和优化方法。对于 Llama 2 70B 参数模型,我们使用 PyTorch/XLA 在 Google Cloud TPU 上实现了 53% 的训练 MFU、17 ms/token 的推理延迟以及 42 tokens/s/chip 的吞吐量。 我们提供了训练用户指南推理用户指南,以便复现本文中的结果。此外,您可以在此处找到我们在 Google Next 2023 的演示文稿

模型概述

Llama 2 有多种尺寸,参数量从 7B 到 70B 不等,可满足不同的需求、计算资源以及训练/推理预算。无论是小型项目还是大规模部署,Llama 模型都提供了多功能性和可扩展性,以适应广泛的应用。

Llama 2 是一个自回归语言模型,采用优化的 Transformer 架构。最大的 70B 模型使用了分组查询注意力 (grouped-query attention),这在不牺牲质量的情况下加快了推理速度。Llama 2 在 2 万亿 token(比 Llama 多 40% 的数据)上进行训练,推理时的上下文长度为 4,096 token(是 Llama 上下文长度的两倍),这使得模型更加准确、流畅且富有创造力。

Llama 2 是一个最先进的 LLM,在推理、编码、熟练度以及知识测试等许多基准测试中,其性能优于许多其他开源语言模型。模型的规模和复杂性对 AI 加速器提出了诸多要求,使其成为衡量 PyTorch/XLA 在 Cloud TPU 上进行 LLM 训练和推理性能的理想基准。

LLM 的性能挑战

对于 Llama 2 等 LLM 进行大规模分布式训练带来了技术挑战,需要实用的解决方案才能最有效地利用 TPU。Llama 的规模会占用 TPU 的内存和处理资源。为解决此问题,我们使用模型分片 (model sharding),即将模型分解为更小的片段,每个片段都适合单个 TPU 核心的容量。这实现了跨多个 TPU 的并行性,在提高训练速度的同时减少了通信开销。

另一个挑战是如何高效管理训练 Llama 2 所需的大型数据集,这需要有效的数据分发和同步方法。此外,优化分布式 TPU 上的学习率调度、梯度聚合以及权重同步等因素对于实现收敛至关重要。

在对 Llama 2 进行预训练或微调后,在该模型 checkpoint 上运行推理会带来额外的技术挑战。我们之前博客文章中讨论的所有挑战,例如自回归解码、变长的输入 prompt 以及模型分片和量化的需求,对于 Llama 2 仍然适用。此外,Llama 2 引入了两个新功能:分组查询注意力 (grouped-query attention) 和提前停止 (early stopping)(当所有 prompt 都到达 eos token 时)。我们讨论了 PyTorch/XLA 如何处理这些挑战,以实现在 Cloud TPU v4 和 v5e 上对 Llama 2 进行高性能、高成本效益的训练和推理。

大规模分布式训练

PyTorch/XLA 提供两种主要的大规模分布式训练方法:SPMD(单程序多数据),它利用 XLA 编译器将单设备程序转换并分割为多设备分布式程序;以及 FSDP,它实现了广泛采用的全分片数据并行 (Fully Sharded Data Parallel) 算法。

在这篇博客文章中,我们展示了如何使用 SPMD API 来标注 HuggingFace (HF) Llama 2 实现,以最大化性能。作为比较,我们还展示了相同配置下的 FSDP 结果;在此处阅读关于PyTorch/XLA FSDP API 的信息。

SPMD 概述

让我们简要回顾一下 SPMD 的基础知识。有关详细信息,请参阅我们的博客文章用户指南

网格 (Mesh)

描述 TPU 设备逻辑拓扑的多维数组

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

分区规范 (Partition Spec)

一个元组,描述相应的张量维度如何在网格上分片

partition_spec = ('x', 'y')

标注分片 (Mark Sharding)

一个接受网格和分区规范并为 XLA 编译器生成分片标注的 API。

tensor = torch.randn(4, 4).to('xla')
# Let's resue the above mesh and partition_spec.
# It means the tensor's 0th dim is sharded 4 way and 1th dim is sharded 2 way.
xs.mark_sharding(tensor, mesh, partition_spec)

使用 SPMD 进行 2D 分片

在我们的SPMD 博客文章中,我们演示了如何使用 1D FSDP 风格的分片。在这里,我们介绍一种更强大的分片策略,称为2D 分片,其中参数和激活都进行分片。这种新的分片策略不仅可以容纳更大的模型,还能将 MFU 提高到最高 54.3%。更多详情,请阅读性能基准部分。

本节介绍一套适用于大多数 LLM 的通用规则,为方便起见,我们直接引用 HF Llama 中的变量名和配置名。

首先,让我们创建一个具有相应轴名称:data 和 model 的 2D 网格。data 轴通常用于分发输入数据,model 轴则用于进一步分发模型。

mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

mesh_shape 可以是一个超参数,根据不同的模型大小和硬件配置进行调整。相同的网格将在所有后续分片标注中重复使用。在接下来的几个部分中,我们将介绍如何使用网格对参数、激活和输入数据进行分片。

参数分片

下表总结了 HF Llama 2 的所有参数及相应的分区规范。HF 示例代码可以在此处找到。

参数名称 说明 参数形状 分区规范 (Partition Spec)
embed_tokens 嵌入层 (vocab_size, hidden_size) (model, data)
q_proj 注意力权重 (num_heads x head_dim, hidden_size) (data, model)
k_proj / v_proj 注意力权重 (num_key_value_heads x head_dim, hidden_size) (data, model)
o_proj 注意力权重 (hidden_size, num_heads x head_dim) (model, data)
gate_proj / up_proj MLP 权重 (intermediate_size, hidden_size) (model, data)
down_proj MLP 权重 (hidden_size, intermediate_size) (data, model)
lm_head HF 输出嵌入 (vocab_size, hidden_size) (model, data)

表 1:SPMD 2D 分片参数分区规范

规则是,将除 QKVO 投影之外的任何权重的 hidden_size 维度根据网格的 data 轴进行分片,然后将另一个维度根据剩余的 model 轴进行分片。对于 QKVO,则反向操作。这种 model-data 轴轮换方法类似于Megatron-LM 的方法,以减少通信开销。对于 layernorm 权重,由于它们是 1D 张量,我们隐式地将其标记为在不同设备上复制。

激活分片

为了更好地利用设备内存,我们经常需要标注一些内存密集型操作的输出。这样编译器就被迫只在设备上保留部分输出,而不是全部输出。在 Llama 2 中,我们显式标注了所有 torch.matmulnn.Linear 的输出。表 2 总结了相应的标注;HF 示例代码可以在此处找到。

输出名称 说明 输出形状 分区规范 (Partition Spec)
inputs_embeds 嵌入层输出 (batch_size, sequence_length, hidden_size) (data, None, model)
query_states 注意力 nn.Linear 输出 (batch_size, sequence_length, num_heads x head_dim) (data, None, model)
key_states / value_states 注意力 nn.Linear 输出 (batch_size, sequence_length, num_key_value_heads x head_dim) (data, None, model)
attn_weights 注意力权重 (batch_size, num_attention_heads, sequence_length, sequence_length) (data, model, None, None)
attn_output 注意力层输出 (batch_size, sequence_length, hidden_size) (data, None, model)
up_proj / gate_proj / down_proj MLP nn.Linear 输出 (batch_size, sequence_length, intermediate_size) (data, None, model)
logits HF 输出嵌入输出 (batch_size, sequence_length, hidden_size) (data, None, model)

表 2:SPMD 2D 分片激活分区规范

规则是,将任何输出的 batch_size 维度根据网格的 data 轴进行分片,然后复制任何输出的长度维度,最后沿 model 轴分片最后一个维度。

输入分片

对于输入分片,规则是沿网格的 data 轴对 batch 维度进行分片,并复制 sequence_length 维度。下面是示例代码,相应的 HF 更改可以在此处找到。

partition_spec = ('data', None)
sharding_spec = xs.ShardingSpec(mesh, partition_spec)
# MpDeviceLoader will shard the input data before sending to the device.
pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, ...)

现在,所有需要分片的数据和模型张量都已涵盖!

优化器状态与梯度

您可能想知道是否也需要对优化器状态和梯度进行分片。好消息是:XLA 编译器的分片传播功能在这两种情况下可以自动完成分片标注,无需额外提示即可提升性能。

重要的是要注意,优化器状态通常在训练循环的第一次迭代中初始化。从 XLA 编译器的角度来看,优化器状态是第一个计算图的输出,因此会传播其分片标注。对于后续迭代,优化器状态成为第二个计算图的输入,其分片标注从第一个计算图传播而来。这也是为什么 PyTorch/XLA 通常会为训练循环生成两个计算图。如果优化器状态在第一次迭代之前以某种方式初始化,用户将需要手动标注它们,就像模型权重一样。

再一次,上述分片标注的所有具体示例都可以在我们的 HF Transformers 分支中找到。该仓库还包含我们实验性功能MultiSlice 的代码,包括 HybridMeshdcn 轴,它们遵循与上述相同的原则。

注意事项

在使用 SPMD 进行训练时,有一些重要事项需要注意:

  • 使用 torch.einsum 而非 torch.matmultorch.matmul 通常会展平张量并在最后执行 torch.mm,这对于 SPMD 来说很不利,尤其是当组合的轴被分片时。XLA 编译器将很难确定如何传播分片。
  • PyTorch/XLA 提供了修补过的 [nn.Linear](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/xla_sharding.py#L570) 来克服上述限制
import torch_xla.experimental.xla_sharding as xs
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

 model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)
  • 始终在所有分片中重复使用相同的网格
  • 始终指定 --dataloader_drop_last yes。最后较小批次的数据难以标注。
  • 在主机上初始化的大型模型可能会导致主机端 OOM(内存溢出)。避免此问题的一种方法是在meta device 上初始化参数,然后逐层创建并分片实际张量。

基础设施改进

除了上述建模技术,我们还开发了额外的功能和改进来最大化性能,包括:

  • 我们启用了异步集合通信。这需要在 XLA 编译器的延迟隐藏调度器上进行增强,以更好地优化 Llama 2 的 PyTorch 代码。
  • 我们现在允许在 IR 图的中间进行分片标注,就像 JAX 的 jax.lax.with_sharding_constraint 一样。以前只对图的输入进行标注。
  • 我们将复制的分片规范从编译器传播到图的输出。这使得我们能够自动分片优化器状态。

推理优化

为 Llama 推理实现的 PyTorch/XLA 所有优化也适用于 Llama 2。这包括使用 torch-xla 集合操作的张量并行 + Dynamo (torch.compile)改进自回归解码逻辑以避免重新编译bucketized prompt 长度、以及带有编译友好索引操作的 KV 缓存。Llama 2 引入了两个新变化:分组查询注意力 (Grouped Query Attention) 和当所有 prompt 都达到 eos 时提前停止 (Early Stopping)。我们应用了相应的更改,以提升 PyTorch/XLA 的性能和灵活性。

分组查询注意力

Llama 2 为 70B 模型启用了分组查询注意力。它允许 Key 和 Value 头的数量小于 Query 头的数量,同时仍支持 KV 缓存分片,最多可达到 KV 头的数量。对于 70B 模型,n_kv_heads 为 8,这将张量并行度限制在小于或等于 8。为了分片模型 checkpoint 并在更多设备上运行,K、V 投影权重需要先复制,然后分割成多个部分。例如,要将 70B 模型 checkpoint 从 8 个部分分片到 16 个部分,K、V 投影权重将被复制并分割成每个分片 2 个部分。我们提供了一个 reshard_checkpoints.py 脚本来处理这一问题,并确保分片后的 checkpoint 在数学上与原始 checkpoint 的表现一致。

EOS 提前停止

Llama 2 的生成代码添加了提前停止逻辑。一个名为 eos_reached 的张量用于跟踪所有 prompt 生成的完成情况,如果 batch 中所有 prompt 都到达了 eos token,生成就会提前停止。PyTorch/XLA 优化版本也包含了类似的更改,并进行了一些微调。

在 PyTorch/XLA 中,将 eos_reached 等张量的值作为控制流条件的一部分进行检查,会触发阻塞性的设备到主机传输。张量会从设备内存传输到 CPU 内存以评估其值,而所有其他逻辑都会等待。这在每次生成新 token 后都会引入毫秒级的延迟。作为权衡,我们将检查 eos_reached 值的频率降低到每生成 10 个新 token 检查一次。通过此更改,阻塞性的设备到主机传输的影响将减少 10 倍,同时提前停止仍然有效,并且在每个序列到达 eos token 后最多只会生成 9 个不必要的 token。

模型服务

PyTorch/XLA 正在研究一种服务策略,使 PyTorch 社区能够通过 Torch.ExportStableHLOSavedModel 来提供深度学习应用服务。PyTorch/XLA Serving 是 PyTorch/XLA 2.1 版本中的一个实验性功能;有关详细信息,请访问我们的服务用户指南。用户可以利用 TorchServe 来运行其单主机工作负载。

性能基准

衡量指标

为了衡量训练性能,我们使用行业标准指标:模型 FLOPS 利用率 (MFU)。模型 FLOPS 是执行单次前向和后向传播所需的浮点运算。模型 FLOPs 与硬件和实现无关,仅取决于底层模型。MFU 衡量模型在训练期间使用实际硬件的效率。实现 100% MFU 意味着模型完美地利用了硬件。

为了衡量推理性能,我们使用行业标准指标:吞吐量。首先,我们在模型编译并加载后测量每个 token 的延迟。然后,我们通过将批量大小 (BS) 除以每个芯片的延迟来计算吞吐量。因此,吞吐量衡量模型在生产环境中的表现,无论使用了多少芯片。

结果

训练评估

图 1 显示了 Llama 2 SPMD 2D 分片在各种 Google TPU v4 硬件上的训练结果,以 PyTorch/XLA FSDP 作为基准。与在相同硬件配置上运行的 FSDP 相比,我们在所有尺寸的 Llama 2 模型上将 MFU 提高了 28%。这种性能提升主要得益于:1) 2D 分片比 FSDP 的通信开销更小,以及 2) SPMD 中启用了异步集合通信,这允许通信和计算重叠。另请注意,随着模型规模的扩大,我们保持了高 MFU。表 3 显示了训练性能基准中使用的所有硬件配置以及一些超参数。

Figure 1. Llama 2 Training MFU on TPU v4 Hardware

图 1:Llama 2 在 TPU v4 硬件上的训练 MFU

图 1 中的结果是使用序列长度为 1,024 生成的。图 2 显示了性能在更大序列长度下的表现。它表明我们的性能也随序列长度线性扩展。由于在 2D 分片中,序列长度轴未被分片,需要减小每个设备的批量大小以适应更大序列长度带来的额外内存压力,因此 MFU 预计会略有下降。而且 TPU 对批量大小非常敏感。对于 Llama 2 70B 参数模型,性能下降幅度低至 4%。在准备这些结果时,Hugging Face Llama 2 分词器将最大模型输入限制为 2,048,这使得我们无法评估更大的序列长度。

Figure 2. Llama 2 SPMD Training MFU on TPU v4 with Different Sequence Lengths

图 2:Llama 2 SPMD 在不同序列长度下于 TPU v4 上的训练 MFU

模型尺寸 7B 13B 70B
TPU 核心数 V4-32 V4-64 V4-256
网格形状 (16, 1) (32, 1) (32, 4)
序列长度 1,024 2,048 1,024 2,048 1,024 2,048
全局批量大小 256 128 256 128 512 256
每设备批量大小 16 8 8 4 16 8

表 3:Llama 2 SPMD 训练性能基准 TPU 配置和超参数

最后要指出的是,我们使用 adafactor 作为优化器以获得更好的内存利用率。再一次强调,此处是用于复现上述性能基准结果的用户指南。

推理评估

在本节中,我们扩展了之前在 Cloud v4 TPU 上对 Llama 的评估。在这里,我们展示了 TPU v5e 在推理应用中的性能特性。

我们将推理吞吐量定义为模型每秒每个 TPU 芯片产生的 token 数量。图 3 显示了 Llama 2 70B 模型在 v5e-16 TPU 节点上的吞吐量。考虑到 Llama 是一个内存受限的应用,我们发现应用仅权重量化使得模型批量大小可以扩展到 32。在更大的 TPU v5e 硬件上可以获得更高的吞吐量结果,直至芯片间的 ICI 网络带宽限制了 TPU 切片实现更高的吞吐量。探索 TPU v5e 在 Llama 2 上的上限不在本文工作范围内。请注意,为了让 Llama 2 70B 模型在 v5e-16 上运行,我们复制了注意力头,使每个芯片有一个头,如上文推理部分所述。如前所述,随着模型批量大小的增加,每个 token 的延迟按比例增长;量化通过减少内存 I/O 需求来改善总体延迟。

Figure 3. Llama 2 70B Inference Per-Chip Throughput on TPU v5e vs. Batch Size

图 3:Llama 2 70B 模型在 TPU v5e 上的单芯片推理吞吐量对比批量大小

图 4 显示了不同模型尺寸下的推理吞吐量结果。这些结果突出了在给定硬件配置和使用 bf16 精度时能达到的最大吞吐量。通过仅权重量化,70B 模型上的吞吐量达到了 42。如上所述,增加硬件资源可能会带来性能提升。

Figure 4. Llama 2 Inference Per-Chip Throughput on TPU v5e

图 4:Llama 2 在 TPU v5e 上的单芯片推理吞吐量

图 5 显示了在 Cloud TPU v5e 上提供 Llama 2 模型服务(来自图 4)的成本。我们根据 us-west4 区域的 3 年承诺(预留)价格报告了 TPU v5e 的单芯片成本。所有模型尺寸均使用最大序列长度 2,048 和最大生成长度 1,000 token。请注意,通过量化,70B 模型的成本降至每 1,000 token $0.0036

Figure 5. Llama 2 Inference Per-Chip Cost on TPU v5e

图 5:Llama 2 在 TPU v5e 上的单芯片推理成本

图 6 总结了我们在 TPU v5e 上获得的最佳 Llama 2 推理延迟结果。Llama 2 7B 的结果来自我们的非量化配置(BF16 权重,BF16 激活),而 13B 和 70B 的结果来自量化配置(INT8 权重,BF16 激活)。我们将这一观察结果归因于量化固有的内存节省与计算开销之间的权衡;因此,对于较小的模型,量化可能不会导致更低的推理延迟。

此外,prompt 长度对 LLM 的内存需求有很大影响。例如,当在 v5e-4 上运行 Llama2 7B,max_seq_len=256,批量大小为 1 且未量化时,我们观察到延迟为 1.2ms / token(即 201 token / 秒 / 芯片)。

Figure 6. Llama 2 Inference Latency on TPU v5e

图 6:Llama 2 在 TPU v5e 上的推理延迟

总结思考

近期的 AI 创新浪潮具有变革意义,其中 LLM 的突破居于前沿。Meta 的 Llama 和 Llama 2 模型是这一进步浪潮中的重要里程碑。PyTorch/XLA 在 Cloud TPU(包括新的 Cloud TPU v5e)上为 Llama 2 及其他 LLM 和生成式 AI 模型提供了独特的高性能、高成本效益的训练和推理能力。展望未来,PyTorch/XLA 将继续在吞吐量和可扩展性方面突破 Cloud TPU 的性能极限,同时保持相同的 PyTorch 用户体验。

我们对 PyTorch/XLA 的未来感到无比兴奋,并邀请社区加入我们。PyTorch/XLA 是完全开源开发的。因此,请在 GitHub 上提交 issue、发送 pull request 和 RFC,以便我们公开协作。您也可以在包括 TPU 和 GPU 在内的各种 XLA 设备上亲自试用 PyTorch/XLA。

我们特别感谢 Marcello Maggioni、Tongfei Guo、Andy Davis、Berkin Ilbeyi 在此项工作中的支持和协作。

致敬,
Google 的 PyTorch/XLA 团队