• 文档 >
  • PyTorch/XLA SPMD 用户指南
快捷方式

PyTorch/XLA SPMD 用户指南

在本用户指南中,我们将讨论 GSPMD 如何集成到 PyTorch/XLA 中,并提供设计概述以说明 SPMD 分片注解 API 及其构造的工作原理。

什么是 PyTorch/XLA SPMD?

GSPMD 是一种用于常见 ML 工作负载的自动并行化系统。XLA 编译器会根据用户提供的分片提示,将单设备程序转换为带有适当集体操作的分区程序。此功能允许开发者像编写运行在单个大型设备上的 PyTorch 程序一样进行编写,无需任何自定义分片计算操作和/或集体通信即可实现扩展。

Execution strategies

*图 1. 两种不同执行策略的比较:(a) 非 SPMD 和 (b) SPMD。*

如何使用 PyTorch/XLA SPMD?

以下是使用 SPMD 的一个简单示例

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh


# Enable XLA SPMD execution mode.
xr.use_spmd()


# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))


t = torch.randn(8, 4).to(xm.xla_device())


# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)

我们来逐一解释这些概念

SPMD 模式

要使用 SPMD,你需要通过 xr.use_spmd() 来启用它。在 SPMD 模式下只有一个逻辑设备。分布式计算和集体操作由 mark_sharding 处理。请注意,用户不能将 SPMD 与其他分布式库混用。

Mesh

对于给定的设备集群,物理 mesh 是互连拓扑的表示。

  1. mesh_shape 是一个元组,其乘积等于物理设备的总数。

  2. device_ids 几乎总是 np.array(range(num_devices))

  3. 建议用户为每个 mesh 维度命名。在上面的例子中,第一个 mesh 维度是 data 维度,第二个 mesh 维度是 model 维度。

你也可以通过以下方式查看更多 mesh 信息

>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])

Partition Spec

partition_spec 的秩与输入 tensor 相同。每个维度描述了对应的输入 tensor 维度如何在设备 mesh 上进行分片。在上面的例子中,tensor t 的第一个维度在 data 维度上进行分片,第二个维度在 model 维度上进行分片。

用户也可以对维度与 mesh 形状不同的 tensor 进行分片。

t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)

# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))

# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))

# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))

进一步阅读

  1. 示例 使用 SPMD 来表示数据并行。

  2. 示例 使用 SPMD 来表示 FSDP(完全分片数据并行)。

  3. SPMD 高级主题

  4. SPMD 分布式检查点

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获取疑问解答

查看资源