作者:Yeounoh Chung, Jon Bolin, Milad Mohammadi, Jiewen Tan, Jack Cao, Joe Spisak, Alex Spiridonov, Shauheen Zahirazami, Steven Krawczyk, Wonjoo Lee Mohit Khatwani, Wanchao Liang, Vaibhav Singh

今天,我们非常高兴地宣布 PyTorch/XLA SPMD 的发布:它将 GSPMD 集成到 PyTorch 中,并提供易于使用的 API。寻求卓越性能和扩展能力的 PyTorch 开发者可以训练和部署最大的神经网络,同时最大化利用 AI 加速器,例如 Google Cloud TPUs。

引言

GSPMD 是一个用于 ML 工作负载的自动并行化系统。XLA 编译器根据用户提供的分片提示,将单设备程序转换为带有适当集合操作(collectives)的分区程序。这使得开发者可以编写 PyTorch 程序,就像它们运行在一个大型设备上一样,无需任何自定义的分片计算和/或集合通信操作来扩展模型。

PyTorch/XLA SPMD 允许 PyTorch 用户以更少的精力、更高的性能使用 GSPMD 来并行化他们的 ML 工作负载。一些主要亮点包括:

  • 更好的开发者体验。只需用户提供少量的分片注解,PyTorch/XLA SPMD 即可实现与最高效的 PyTorch 分片实现相当的性能(参见下方的示例和结果部分)。PyTorch/XLA SPMD 将 ML 模型编程任务与并行化挑战分离开来。其自动化模型分片方法让用户无需自行实现带有适当集合操作的分片版本操作。
  • 一个单一的 API,支持多种并行算法(包括数据并行、完全分片数据并行、空间分区张量并行和流水线并行,以及这些算法的组合),适用于不同的 ML 工作负载和模型架构。
  • 大型模型训练中的行业领先性能。PyTorch/XLA SPMD 将强大的 XLA GSPMD 引入 PyTorch,使用户能够充分利用 Google Cloud TPUs 的强大能力。
  • 使 PyTorch 和 JAX 开发者能够利用相同的底层 XLA API 来扩展模型。

关键概念

分片注解 API 背后的关键概念是:1) Mesh,2) Partition Spec,以及 3) 使用 Mesh 和 Partition Spec 表达分片意图的 mark_sharding API。更详细的设计概述可作为用户指南在此处获取。

Mesh (网格)

对于给定的设备集群,物理 Mesh 是互连拓扑的表示。

我们基于这种拓扑结构派生出逻辑 Mesh,以创建可用于对模型中张量不同轴进行分区的设备子组。我们将分片注解应用于逻辑 Mesh 上的程序映射;这会自动在程序图中插入通信集合操作,以支持功能正确性(参见下图)。

SPMD on PyTorch/XLA

我们使用 Mesh API 抽象逻辑 Mesh。逻辑 Mesh 的轴可以命名。这是一个示例:

import numpy as np
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' nad 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

mesh.get_logical_mesh()
>> array([[0, 1],
          [2, 3],
          [4, 5],
          [6, 7]])
mesh.shape()
>> OrderedDict([('x', 4), ('y', 2)])

Partition Spec (分区规范)

`partition_spec` 与输入张量具有相同的秩 (rank)。每个维度描述了相应的输入张量维度如何在设备 Mesh(逻辑上由 `mesh_shape` 定义)上进行分片。partition_spec 是一个元组 (tuple),包含 device_mesh 维度的 index、`None`,或者 Mesh 维度索引的元组。如果相应的 Mesh 维度已命名,则 index 可以是 intstr。这指定了每个输入秩如何进行分片(indexmesh_shape)或复制(None)。

# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (4, 2), ('data', 'model'))
partition_spec = ('model', 'data')
xs.mark_sharding(input_tensor, mesh, partition_spec)

我们支持原始 GSPMD 论文中描述的所有三种分片类型。例如,可以像这样指定部分复制:

# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (2, 2, 2), ('x', 'y', 'z'))

# evenly shard across x and z and replicate among y
partition_spec = ('x', 'z')  # equivalent to ('x', None, 'z')
xs.mark_sharding(input_tensor, mesh, partition_spec)

分片注解的简单示例

用户可以使用 mark_sharding API () 对原生 PyTorch 张量进行注解。该 API 接收 torch.Tensor 作为输入,并返回一个 XLAShardedTensor 作为输出。

def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor

调用 mark_sharding API 需要用户定义的逻辑 Meshpartition_spec,并为 XLA 编译器生成分片注解。分片规范会附加到 XLATensor 和原始输入张量上。以下是一个来自 [RFC] 的简单用法示例,用以说明分片注解 API 的工作方式:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devicese // 2)  # 2x4 on v3-8, 2x2 on v4-8  
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

t = torch.randn(8, 4).to(xm.xla_device())

# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = (0, 1)
m1_sharded = xs.mark_sharding(t, mesh, partition_spec)
assert isinstance(m1_sharded, XLAShardedTensor) == True
# Note that the sharding annotation is also in-placed updated to t

我们可以在 PyTorch 程序中注解不同的张量,以启用不同的并行技术,如下方注释所示:

# Sharding annotate the linear layer weights. SimpleLinear() is a nn.Module.
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)

# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  # Assumes `loader` returns data, target on XLA device
  optimizer.zero_grad()
  # Sharding annotate input data, we can shard any input
  # dimensions. Sharding the batch dimension enables 
  # data parallelism, sharding the feature dimension enables
  # spatial partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  ouput = model(data)
  loss = loss_fn(output, target)
  optimizer.step()
  xm.mark_step()

更完整的单元测试用例和集成测试示例可在 PyTorch/XLA 仓库中找到。

结果

性能

我们使用 GPT-2 模型 () 测量了 PyTorch/XLA SPMD 的性能,并与用户模式 FSDP 进行了比较。

此处,SPMD 应用了与 FSDP 图中相同的分片方案(即 1D 分片)。通过探索更高级的 SPMD 分片方案,用户有望获得更好的 MFU 结果。

SPMD vs. FSDP

我们使用模型浮点运算利用率 (MFU) 作为比较指标。MFU 是“观察到的吞吐量与系统在峰值 FLOPs 下的理论最大吞吐量之比” (PaLM 论文)。

flops_per_step = 6 * global_batch_size * seq_len * num_params
model_flops_utilization = flops_per_step / step_time(s) / chip_count / flops_per_chip

这个估算假设输入维度远大于输入序列长度 (d_model » seq_len)。如果这个假设被违反,自注意力机制 (self-attention) 的 FLOPs 开始变得足够显著,这个表达式将低估真实的 MFU。

可扩展性

SPMD 的核心优势之一是灵活的分区能力,这可以用于节省加速器内存 (HBM) 使用并提高可扩展性。对于可扩展性分析,我们进行了两项研究:1) 使用 Hugging Face transformers (GPT-2) 作为基础实现,检查了 4 种模型尺寸下的峰值 HBM 使用情况;2) 检查了使用空间分区时的峰值 HBM 使用情况。

Peak HBM Utilization

上图显示未分片的 20亿参数模型的峰值内存占用为 26GB(红色虚线)。对模型权重进行分片(模型并行)降低了峰值内存占用,因此可以在给定的 TPU pod 切片上进行更大模型的训练。在这些实验中,我们在 Google Cloud TPU v4-16 上,对于一个 40亿参数的模型,实现了高达 39.75% 的 MFU。

我们还在 Cloud TPU v4-8 上,使用空间分区和一个简单的 ResNet50 示例 () 进行了输入批次可扩展性测试。对于数据并行 (DDP, FSDP),输入批次通常在批次维度上进行分片,但 PyTorch/XLA SPMD 允许在输入特征维度上进行输入分片,以实现空间分片。如下图所示,通过空间分区可以将单设备批次大小推高至 512,这在其他数据并行技术中是不可能实现的。

Batch size scaling with spatial partitioning

PyTorch/XLA SPMD 的未来之路

我们对 PyTorch/XLA 的未来感到兴奋,并邀请社区加入我们。SPMD 仍处于实验阶段,我们不断为其添加新功能。在未来的版本中,我们计划解决异步数据加载、部分复制分片和其他改进。我们非常乐意听取您的意见,解答您关于 PyTorch/XLA SPMD 的问题,并了解您如何使用 SPMD。

祝好!

Google 的 PyTorch/XLA 团队