PyTorch/XLA SPMD 用户指南¶
在本用户指南中,我们将讨论 GSPMD 如何集成到 PyTorch/XLA 中,并提供设计概述以说明 SPMD 分片注解 API 及其构造的工作原理。
什么是 PyTorch/XLA SPMD?¶
GSPMD 是一种用于常见 ML 工作负载的自动并行化系统。XLA 编译器会根据用户提供的分片提示,将单设备程序转换为带有适当集体操作的分区程序。此功能允许开发者像编写运行在单个大型设备上的 PyTorch 程序一样进行编写,无需任何自定义分片计算操作和/或集体通信即可实现扩展。

如何使用 PyTorch/XLA SPMD?¶
以下是使用 SPMD 的一个简单示例
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
t = torch.randn(8, 4).to(xm.xla_device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
我们来逐一解释这些概念
SPMD 模式¶
要使用 SPMD,你需要通过 xr.use_spmd()
来启用它。在 SPMD 模式下只有一个逻辑设备。分布式计算和集体操作由 mark_sharding
处理。请注意,用户不能将 SPMD 与其他分布式库混用。
Mesh¶
对于给定的设备集群,物理 mesh 是互连拓扑的表示。
mesh_shape
是一个元组,其乘积等于物理设备的总数。device_ids
几乎总是np.array(range(num_devices))
。建议用户为每个 mesh 维度命名。在上面的例子中,第一个 mesh 维度是
data
维度,第二个 mesh 维度是model
维度。
你也可以通过以下方式查看更多 mesh 信息
>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])
Partition Spec¶
partition_spec 的秩与输入 tensor 相同。每个维度描述了对应的输入 tensor 维度如何在设备 mesh 上进行分片。在上面的例子中,tensor t
的第一个维度在 data
维度上进行分片,第二个维度在 model
维度上进行分片。
用户也可以对维度与 mesh 形状不同的 tensor 进行分片。
t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)
# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))
# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))
# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))