作者:Jiewen Tan, Jon Bolin, Yeounoh Chung, Liyang Lu, Siyuan Liu, Wonjoo Lee, Manfei Bai, Meghan Cowan, Jack Cao, Milad Mohammadi, Shauheen Zahirazami, Alex Spiridonov

在人工智能创新以前所未有的速度加速发展的背景下,Meta 的 Llama 系列开源大型语言模型 (LLM) 是一项引人注目的突破。Llama 标志着 LLM 向前迈出了重要一步,展示了预训练架构在广泛应用中的强大功能。Llama 2 进一步突破了规模和能力的界限,激发了语言理解、生成及其他领域的进步。

在 Llama 发布后不久,我们发布了一篇 博客文章,展示了在 Cloud TPU v4 上使用 PyTorch/XLA 实现 Llama 的超低推理延迟。基于这些结果,今天,我们很自豪地分享在 Cloud TPU v4 和我们最新的 AI 超级计算机 Cloud TPU v5e 上使用 PyTorch/XLA 进行 Llama 2 模型训练和推理的性能表现。

在这篇博文中,我们以 Llama 2 模型为例,展示 PyTorch/XLA 在 Cloud TPU 上进行 LLM 训练和推理的强大功能。我们讨论了用于提高推理吞吐量和训练模型 FLOPs 利用率 (MFU) 的计算技术和优化方法。对于 Llama 2 70B 参数模型,我们在 Google Cloud TPU 上使用 PyTorch/XLA 实现了 53% 的训练 MFU、17 毫秒/token 的推理延迟以及 42 tokens/s/芯片的吞吐量。 我们提供了 训练用户指南推理用户指南,用于重现本文中的结果。此外,您可以在 此处观看我们的 Google Next 2023 演讲

模型概述

Llama 2 模型有多种尺寸,从 7B 到 70B 参数不等,以满足不同的需求、计算资源以及训练/推理预算。无论是小型项目还是大型部署,Llama 模型都提供了多功能性和可扩展性,以适应广泛的应用。

Llama 2 是一种自回归语言模型,它使用优化的 Transformer 架构。最大的 70B 模型使用了分组查询注意力机制,这在不牺牲质量的情况下加快了推理速度。Llama 2 模型在 2 万亿个 token 上进行了训练(比 Llama 多 40% 的数据),并且推理的上下文长度为 4,096 个 token(是 Llama 上下文长度的两倍),这提高了模型的准确性、流畅性和创造性。

Llama 2 是一种最先进的 LLM,在许多基准测试中优于许多其他开源语言模型,包括推理、编码、熟练度和知识测试。模型的规模和复杂性对 AI 加速器提出了许多要求,使其成为 PyTorch/XLA 在 Cloud TPU 上进行 LLM 训练和推理性能的理想基准。

LLM 的性能挑战

针对 Llama 2 等 LLM 的大规模分布式训练引入了技术挑战,这些挑战需要实用的解决方案才能最有效地利用 TPU。Llama 的规模可能会给 TPU 的内存和处理资源带来压力。为了解决这个问题,我们使用了模型分片,即将模型分解成更小的片段,每个片段都适合单个 TPU 核心的容量。这实现了跨多个 TPU 的并行性,提高了训练速度,同时减少了通信开销。

另一个挑战是有效地管理训练 Llama 2 所需的大型数据集,这需要有效的数据分发和同步方法。此外,优化学习率调度、梯度聚合和跨分布式 TPU 的权重同步等因素对于实现收敛至关重要。

在预训练或微调 Llama 2 后,在模型检查点上运行推理会产生额外的技术挑战。我们在 之前的博客文章 中讨论的所有挑战,例如自回归解码、可变输入提示长度以及模型分片和量化的需求,仍然适用于 Llama 2。此外,Llama 2 还引入了两项新功能:分组查询注意力机制和提前停止。我们将讨论 PyTorch/XLA 如何处理这些挑战,从而在 Cloud TPU v4 和 v5e 上实现 Llama 2 的高性能、高性价比的训练和推理。

大规模分布式训练

PyTorch/XLA 提供了两种主要的进行大规模分布式训练的方式:SPMD,它利用 XLA 编译器将单设备程序转换和分区为多设备分布式程序;以及 FSDP,它实现了广泛采用的 Fully Sharded Data Parallel 算法。

在这篇博文中,我们将展示如何使用 SPMD API 注释 HuggingFace (HF) Llama 2 的实现,以最大化性能。为了进行比较,我们还展示了使用相同配置的 FSDP 结果;请阅读 此处 了解 PyTorch/XLA FSDP API。

SPMD 概述

让我们简要回顾一下 SPMD 的基本原理。有关详细信息,请参阅我们的 博客文章用户指南

网格 (Mesh)

一个多维数组,描述了 TPU 设备的逻辑拓扑结构

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

分区规范 (Partition Spec)

一个元组,描述了如何跨网格对相应张量的维度进行分片

partition_spec = ('x', 'y')

标记分片 (Mark Sharding)

一个 API,它接受网格和分区规范,然后为 XLA 编译器生成分片注释。

tensor = torch.randn(4, 4).to('xla')
# Let's resue the above mesh and partition_spec.
# It means the tensor's 0th dim is sharded 4 way and 1th dim is sharded 2 way.
xs.mark_sharding(tensor, mesh, partition_spec)

使用 SPMD 进行 2D 分片

在我们的 SPMD 博客文章 中,我们演示了使用 1D FSDP 风格的分片。在这里,我们介绍一种更强大的分片策略,称为 2D 分片,其中参数和激活都被分片。这种新的分片策略不仅允许容纳更大的模型,而且还将 MFU 提高到高达 54.3%。有关更多详细信息,请阅读基准测试部分。

本节介绍一组适用于大多数 LLM 的通用规则,为了方便起见,我们直接引用了 HF Llama 中的变量名和配置名称。

首先,让我们创建一个带有相应轴名称的 2D 网格:数据 (data) 和模型 (model)。数据轴通常是我们分发输入数据的地方,模型轴是我们进一步分发模型的地方。

mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

mesh_shape 可以是一个超参数,可以针对不同的模型大小和硬件配置进行调整。相同的网格将在所有后续的分片注释中重复使用。在接下来的几节中,我们将介绍如何使用网格对参数、激活和输入数据进行分片。

参数分片

下表总结了 HF Llama 2 的所有参数以及相应的分区规范。示例 HF 代码可以在 此处 找到。

参数名称 解释 参数形状 分区规范 (Partition Spec)
embed_tokens 嵌入层 (vocab_size, hidden_size) (model, data)
q_proj 注意力权重 (num_heads x head_dim, hidden_size) (data, model)
k_proj / v_proj 注意力权重 (num_key_value_heads x head_dim, hidden_size) (data, model)
o_proj 注意力权重 (hidden_size, num_heads x head_dim) (model, data)
gate_proj / up_proj MLP 权重 (intermediate_size, hidden_size) (model, data)
down_proj MLP 权重 (hidden_size, intermediate_size) (data, model)
lm_head HF 输出嵌入 (vocab_size, hidden_size) (model, data)

表 1:SPMD 2D 分片参数分区规范

规则是根据网格的 data 轴对除 QKVO 投影外的任何权重的 hidden_size 维度进行分片,然后使用剩余的 model 轴对另一个维度进行分片。对于 QKVO,则相反。这种模型-数据轴旋转方法类似于 Megatron-LM 的方法,旨在减少通信开销。对于 layernorm 权重,我们隐式地将它们标记为在不同设备之间复制,因为它们是 1D 张量。

激活分片

为了更好地利用设备内存,我们通常需要注释一些内存受限操作的输出。这样,编译器就会强制设备上只保留部分输出,而不是完整输出。在 Llama 2 中,我们显式地注释了所有 torch.matmulnn.Linear 的输出。表 2 总结了相应的注释;示例 HF 代码可以在 此处 找到。

输出名称 解释 输出形状 分区规范 (Partition Spec)
inputs_embeds 嵌入层输出 (batch_size, sequence_length, hidden_size) (data, None, model)
query_states 注意力 nn.Linear 输出 (batch_size, sequence_length, num_heads x head_dim) (data, None, model)
key_states / value_states 注意力 nn.Linear 输出 (batch_size, sequence_length, num_key_value_heads x head_dim) (data, None, model)
attn_weights 注意力权重 (batch_size, num_attention_heads, sequence_length, sequence_length) (data, model, None, None)
attn_output 注意力层输出 (batch_size, sequence_length, hidden_size) (data, None, model)
up_proj / gate_proj / down_proj MLP nn.Linear 输出 (batch_size, sequence_length, intermediate_size) (data, None, model)
logits HF 输出嵌入输出 (batch_size, sequence_length, hidden_size) (data, None, model)

表 2:SPMD 2D 分片激活分区规范

规则是根据网格的 data 轴对任何输出的 batch_size 维度进行分片,然后复制任何输出的长度维度,最后沿着 model 轴对最后一个维度进行分片。

输入分片

对于输入分片,规则是沿着网格的 data 轴对批次维度进行分片,并复制 sequence_length 维度。以下是示例代码,相应的 HF 更改可以在 此处 找到。

partition_spec = ('data', None)
sharding_spec = xs.ShardingSpec(mesh, partition_spec)
# MpDeviceLoader will shard the input data before sending to the device.
pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, ...)

现在,所有需要分片的数据和模型张量都已涵盖!

优化器状态和梯度

您可能想知道是否也需要对优化器状态和梯度进行分片。好消息:XLA 编译器的分片传播功能可以自动完成这两种情况下的分片注释,而无需更多提示来提高性能。

重要的是要注意,优化器状态通常在训练循环的第一次迭代中初始化。从 XLA 编译器的角度来看,优化器状态是第一个图的输出,因此具有传播的分片注释。对于后续迭代,优化器状态成为第二个图的输入,其分片注释从第一个图传播而来。这也是 PyTorch/XLA 通常为训练循环生成两个图的原因。如果优化器状态在第一次迭代之前以某种方式初始化,则用户将必须手动注释它们,就像模型权重一样。

同样,以上所有分片注释的具体示例都可以在我们的 HF Transformers 分支 此处 找到。该仓库还包含我们实验性功能 MultiSlice 的代码,包括 HybridMeshdcn 轴,它们遵循上述相同的原则。

注意事项

在使用 SPMD 进行训练时,有几点需要注意:

  • 使用 torch.einsum 而不是 torch.matmultorch.matmul 通常会展平张量并在最后执行 torch.mm,当组合轴被分片时,这对 SPMD 不利。XLA 编译器将很难确定如何传播分片。
  • PyTorch/XLA 提供了修补后的 [nn.Linear](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/xla_sharding.py#L570) 来克服上述限制
import torch_xla.experimental.xla_sharding as xs
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

 model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)
  • 始终在所有分片中重复使用相同的网格
  • 始终指定 --dataloader_drop_last yes。最后一个较小的数据很难注释。
  • 在主机上初始化的大型模型可能会导致主机端 OOM。避免此问题的一种方法是在 meta 设备 上初始化参数,然后逐层创建和分片真实张量。

基础设施改进

除了上述建模技术外,我们还开发了其他功能和改进来最大化性能,包括:

  • 我们启用了异步集体通信。这需要在 XLA 编译器的延迟隐藏调度器上进行增强,以便更好地优化 Llama 2 PyTorch 代码。
  • 我们现在允许在 IR 图的中间进行分片注释,就像 JAX 的 jax.lax.with_sharding_constraint 一样。以前,只有图输入被注释。
  • 我们还从编译器向图输出传播复制的分片规范。这使我们能够自动分片优化器状态。

推理优化

为 Llama 推理实现的所有 PyTorch/XLA 优化 也适用于 Llama 2。这包括 张量并行 + Dynamo (torch.compile),使用 torch-xla 集体运算改进自回归解码逻辑以避免重新编译分桶提示长度带有编译友好索引操作的 KV 缓存。Llama 2 引入了两项新更改:分组查询注意力机制,以及当所有提示都达到 eos 时提前停止。我们应用了相应的更改,以利用 PyTorch/XLA 提升更好的性能和灵活性。

分组查询注意力机制

Llama 2 为 70B 模型启用了 分组查询注意力机制。它允许 Key 和 Value 头部的数量小于 Query 头部的数量,同时仍然支持 KV 缓存分片,最多可达 KV 头部的数量。对于 70B 模型,n_kv_heads 为 8,这限制了张量并行度小于或等于 8。为了对模型检查点进行分片以在更多设备上运行,需要先复制 K、V 投影权重,然后再将其拆分为多个部分。例如,要将 70B 模型检查点从 8 个部分分片到 16 个部分,K、V 投影权重将被复制并拆分为每个分片的 2 个部分。我们提供了一个 reshard_checkpoints.py 脚本来处理这个问题,并确保分片后的检查点在数学上与原始检查点执行结果相同。

EOS 提前停止

Llama 2 生成代码添加了 提前停止逻辑。一个 eos_reached 张量用于跟踪所有提示生成的完成情况,如果批次中所有提示都达到了 eos token,则生成将提前停止。类似的更改也已纳入 PyTorch/XLA 优化版本中,并进行了一些小的调整。

在 PyTorch/XLA 中,检查像 eos_reached 这样的张量的值作为控制流条件的一部分,会调用阻塞的设备到主机传输。张量将从设备内存传输到 CPU 内存以评估其值,而所有其他逻辑都在等待。这会在每次生成新 token 后引入毫秒级的延迟。作为一种权衡,我们将检查 eos_reached 值的频率降低到 每生成 10 个新 token 检查一次。通过此更改,阻塞的设备到主机传输的影响将减少 10 倍,而提前停止仍然有效,并且在每个序列达到 eos token 后最多会生成 9 个不必要的 token。

模型服务

PyTorch/XLA 正在开发一种服务策略,使 PyTorch 社区能够通过 Torch.ExportStableHLOSavedModel 为其深度学习应用程序提供服务。PyTorch/XLA Serving 是 PyTorch/XLA 2.1 版本 中的一项实验性功能;有关详细信息,请访问我们的 服务用户指南。用户可以利用 TorchServe 运行其单主机工作负载。

基准测试

指标

为了衡量训练性能,我们使用了行业标准指标:模型 FLOPs 利用率 (MFU)。模型 FLOPs 是执行单次前向和后向传播所需的浮点运算次数。模型 FLOPs 与硬件和实现无关,仅取决于底层模型。MFU 衡量模型在训练期间有效利用实际硬件的程度。实现 100% MFU 意味着模型完美地利用了硬件。

为了衡量推理性能,我们使用了行业标准的吞吐量指标。首先,我们测量模型编译和加载后的每个 token 的延迟。然后,我们通过将批次大小 (BS) 除以每个芯片的延迟来计算吞吐量。因此,吞吐量衡量的是模型在生产环境中的性能,而与使用了多少芯片无关。

结果

训练评估

图 1 显示了在各种 Google TPU v4 硬件上使用 PyTorch/XLA FSDP 作为基线的 Llama 2 SPMD 2D 分片训练结果。与在相同硬件配置上运行的 FSDP 相比,我们在所有尺寸的 Llama 2 模型上将 MFU 提高了 28%。这种性能提升主要归功于:1) 2D 分片比 FSDP 的通信开销更少,以及 2) SPMD 中启用了异步集体通信,这允许通信和计算重叠。另请注意,随着模型尺寸的扩大,我们保持了较高的 MFU。表 3 显示了所有硬件配置以及训练基准测试中使用的一些超参数。

Figure 1. Llama 2 Training MFU on TPU v4 Hardware

图 1:TPU v4 硬件上的 Llama 2 训练 MFU

图 1 中的结果是在序列长度为 1,024 的情况下生成的。图 2 显示了性能如何随更大的序列长度变化。它表明我们的性能也随着序列长度线性扩展。由于更大的序列长度引入了额外的内存压力,因此需要更小的每设备批次大小来适应,因为序列长度轴在 2D 分片中未被分片,因此 MFU 预计会略有下降。TPU 对批次大小非常敏感。对于 Llama 2 70B 参数模型,性能下降幅度低至 4%。在准备这些结果时,Hugging Face Llama 2 分词器 将最大模型输入限制为 2,048,这阻止了我们评估更大的序列长度。

Figure 2. Llama 2 SPMD Training MFU on TPU v4 with Different Sequence Lengths

图 2:在不同序列长度下,TPU v4 上 Llama 2 SPMD 训练 MFU

模型尺寸 7B 13B 70B
TPU NumCores V4-32 V4-64 V4-256
网格形状 (16, 1) (32, 1) (32, 4)
1,024 2,048 1,024 2,048 1,024 2,048
全局批次大小 256 128 256 128 512 256
每设备批次大小 16 8 8 4 16 8

表 3:Llama 2 SPMD 训练基准测试 TPU 配置和超参数

最后需要指出的是,我们使用 adafactor 作为优化器,以获得更好的内存利用率。再次强调,这里是 用户指南,用于重现上面列出的基准测试结果。

推理评估

在本节中,我们将扩展我们 之前对 Cloud v4 TPU 上 Llama 的评估。在这里,我们展示了 TPU v5e 在推理应用中的性能特性。

我们将推理吞吐量定义为每个 TPU 芯片每秒生成的 token 数量。图 3 显示了 v5e-16 TPU 节点上 Llama 2 70B 模型的吞吐量。鉴于 Llama 是一种内存受限的应用,我们看到应用仅权重量化可以解锁扩展模型批次大小至 32。在更大的 TPU v5e 硬件上,可能会实现更高的吞吐量结果,直至芯片之间的 ICI 网络带宽限制 TPU 切片提供更高的吞吐量。探索 TPU v5e 在 Llama 2 上的上限超出了本文的工作范围。请注意,为了使 Llama 2 70B 模型在 v5e-16 上运行,我们复制了注意力头部,使每个芯片有一个头部,如上文推理部分所述。正如 之前 讨论的那样,随着模型批次大小的增加,每个 token 的延迟会成比例增长;量化通过减少内存 I/O 需求来提高整体延迟。

Figure 3. Llama 2 70B Inference Per-Chip Throughput on TPU v5e vs. Batch Size

图 3:TPU v5e 上 Llama 2 70B 模型推理每芯片吞吐量与批次大小的关系

图 4 显示了不同模型尺寸的推理吞吐量结果。这些结果突出了在使用 bf16 精度时,给定硬件配置的最大吞吐量。通过仅权重量化,70B 模型的吞吐量达到 42。如上所述,增加硬件资源可能会带来性能提升。

Figure 4. Llama 2 Inference Per-Chip Throughput on TPU v5e

图 4:TPU v5e 上 Llama 2 推理每芯片吞吐量

图 5 显示了在 Cloud TPU v5e 上服务 Llama 2 模型(来自图 4)的成本。我们根据 us-west4 地区的 3 年承诺(预留)价格报告了 TPU v5e 每芯片的成本。所有模型尺寸均使用 2,048 的最大序列长度和 1,000 个 token 的最大生成长度。请注意,通过量化,70B 模型的成本降至 每 1,000 个 token 0.0036 美元

Figure 5. Llama 2 Inference Per-Chip Cost on TPU v5e

图 5:TPU v5e 上 Llama 2 推理每芯片成本

图 6 总结了我们在 TPU v5e 上获得的最佳 Llama 2 推理延迟结果。Llama 2 7B 模型的结果来自我们的非量化配置(BF16 权重,BF16 激活),而 13B 和 70B 模型的结果来自量化配置(INT8 权重,BF16 激活)。我们将此观察结果归因于量化的固有内存节省与计算开销之间的权衡;因此,对于较小的模型,量化可能不会降低推理延迟。

此外,提示长度对 LLM 的内存需求有很大影响。例如,我们观察到在 v5e-4 上运行 Llama2 7B 模型,当 max_seq_len=256,批次大小为 1 且未进行量化时,延迟为 1.2 毫秒/token(即 201 tokens/秒/芯片)。

Figure 6. Llama 2 Inference Latency on TPU v5e

图 6:TPU v5e 上 Llama 2 推理延迟

最终想法

最近一波人工智能创新浪潮可谓是变革性的,其中大型语言模型 (LLM) 的突破首当其冲。Meta 的 Llama 和 Llama 2 模型是这波浪潮中值得关注的里程碑。PyTorch/XLA 独特地实现了 Llama 2 和其他 LLM 以及生成式 AI 模型在 Cloud TPU(包括新的 Cloud TPU v5e)上的高性能、高性价比的训练和推理。展望未来,PyTorch/XLA 将继续在 Cloud TPU 上突破吞吐量和可扩展性的性能极限,同时保持相同的 PyTorch 用户体验。

我们对 PyTorch/XLA 的未来发展感到无比兴奋,并邀请社区加入我们。PyTorch/XLA 完全以开源方式开发。因此,请在 GitHub 上提交问题、提交拉取请求和发送 RFC,以便我们可以公开协作。您也可以亲自试用 PyTorch/XLA 在各种 XLA 设备(包括 TPU 和 GPU)上的表现。

我们要特别感谢 Marcello Maggioni、Tongfei Guo、Andy Davis、Berkin Ilbeyi 在此项工作中给予的支持和协作。

致谢,
Google PyTorch/XLA 团队