PyTorch/XLA SPMD 用户指南¶
在本用户指南中,我们将讨论 GSPMD 如何集成到 PyTorch/XLA 中,并提供设计概述,以说明 SPMD 分片注解 API 及其构造的工作原理。
什么是 PyTorch/XLA SPMD?¶
GSPMD 是一个用于常见 ML 工作负载的自动并行化系统。XLA 编译器将根据用户提供的分片提示,将单设备程序转换为具有适当集合通信的分区程序。此功能允许开发者编写 PyTorch 程序,就像它们在单个大型设备上一样,而无需任何自定义分片计算操作和/或集合通信来进行扩展。
data:image/s3,"s3://crabby-images/0f1d1/0f1d17c449cc53c4ddb8341ff29b80f935ade9e2" alt="Execution strategies"
*图 1. 两种不同执行策略的比较,(a) 用于非 SPMD,(b) 用于 SPMD。*
如何使用 PyTorch/XLA SPMD?¶
这是一个使用 SPMD 的简单示例
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
t = torch.randn(8, 4).to(xm.xla_device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
让我们逐一解释这些概念
SPMD 模式¶
为了使用 SPMD,您需要通过 xr.use_spmd()
启用它。在 SPMD 模式下,只有一个逻辑设备。分布式计算和集合通信由 mark_sharding
处理。请注意,用户不能将 SPMD 与其他分布式库混合使用。
Mesh(网格)¶
对于给定的设备集群,物理网格是互连拓扑的表示。
mesh_shape
是一个元组,它将被乘到物理设备的总数。device_ids
几乎总是np.array(range(num_devices))
。还鼓励用户为每个网格维度命名。在上面的示例中,第一个网格维度是
data
维度,第二个网格维度是model
维度。
您还可以通过以下方式查看更多网格信息
>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])
Partition Spec(分区规范)¶
partition_spec 与输入张量具有相同的秩。每个维度描述了相应的输入张量维度如何在设备网格上进行分片。在上面的示例中,张量 t
的第一个维度在 data
维度上进行分片,第二个维度在 model
维度上进行分片。
用户还可以对维度与网格形状不同的张量进行分片。
t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)
# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))
# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))
# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))