• 文档 >
  • PyTorch/XLA SPMD 用户指南
快捷方式

PyTorch/XLA SPMD 用户指南

在本用户指南中,我们将讨论 GSPMD 如何集成到 PyTorch/XLA 中,并提供设计概述,以说明 SPMD 分片注解 API 及其构造的工作原理。

什么是 PyTorch/XLA SPMD?

GSPMD 是一个用于常见 ML 工作负载的自动并行化系统。XLA 编译器将根据用户提供的分片提示,将单设备程序转换为具有适当集合通信的分区程序。此功能允许开发者编写 PyTorch 程序,就像它们在单个大型设备上一样,而无需任何自定义分片计算操作和/或集合通信来进行扩展。

Execution strategies

*图 1. 两种不同执行策略的比较,(a) 用于非 SPMD,(b) 用于 SPMD。*

如何使用 PyTorch/XLA SPMD?

这是一个使用 SPMD 的简单示例

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh


# Enable XLA SPMD execution mode.
xr.use_spmd()


# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))


t = torch.randn(8, 4).to(xm.xla_device())


# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)

让我们逐一解释这些概念

SPMD 模式

为了使用 SPMD,您需要通过 xr.use_spmd() 启用它。在 SPMD 模式下,只有一个逻辑设备。分布式计算和集合通信由 mark_sharding 处理。请注意,用户不能将 SPMD 与其他分布式库混合使用。

Mesh(网格)

对于给定的设备集群,物理网格是互连拓扑的表示。

  1. mesh_shape 是一个元组,它将被乘到物理设备的总数。

  2. device_ids 几乎总是 np.array(range(num_devices))

  3. 还鼓励用户为每个网格维度命名。在上面的示例中,第一个网格维度是 data 维度,第二个网格维度是 model 维度。

您还可以通过以下方式查看更多网格信息

>>> mesh.shape()
OrderedDict([('data', 4), ('model', 1)])

Partition Spec(分区规范)

partition_spec 与输入张量具有相同的秩。每个维度描述了相应的输入张量维度如何在设备网格上进行分片。在上面的示例中,张量 t 的第一个维度在 data 维度上进行分片,第二个维度在 model 维度上进行分片。

用户还可以对维度与网格形状不同的张量进行分片。

t1 = torch.randn(8, 8, 16).to(device)
t2 = torch.randn(8).to(device)

# First dimension is being replicated.
xs.mark_sharding(t1, mesh, (None, 'data', 'model'))

# First dimension is being sharded at data dimension.
# model dimension is used for replication when omitted.
xs.mark_sharding(t2, mesh, ('data',))

# First dimension is sharded across both mesh axes.
xs.mark_sharding( t2, mesh, (('data', 'model'),))

进一步阅读

  1. 示例,使用 SPMD 来表达数据并行性。

  2. 示例,使用 SPMD 来表达 FSDP(Fully Sharded Data Parallel,全分片数据并行)。

  3. SPMD 高级主题

  4. Spmd 分布式检查点

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源