PyTorch/XLA SPMD 高级主题¶
本文档将介绍 GSPMD 的一些高级主题。在继续阅读本文档之前,请先阅读SPMD 用户指南。
分片感知的宿主机到设备数据加载¶
PyTorch/XLA SPMD 接收一个单设备程序,对其进行分片并并行执行。SPMD 执行需要使用原生的 PyTorch DataLoader,它同步地将数据从宿主机传输到 XLA 设备。这会在每一步的输入数据传输期间阻塞训练。为了提高原生数据加载性能,我们让 PyTorch/XLA ParallelLoader 直接支持输入分片(src),当传递可选关键字参数 _input_sharding_ 时。
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# assume 4d input and we want to shard at the batch dimension.
input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
如果批处理的每个元素形状不同,也可以为其指定不同的 input_sharding
。
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor for shape [s1, s2]>}
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# specify different sharding for each input of the batch.
input_sharding={
'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
'y': xs.ShardingSpec(input_mesh, ('data', None))
}
)
虚拟设备优化¶
PyTorch/XLA 通常在张量定义后异步地将张量数据从宿主机传输到设备。这是为了让数据传输与图跟踪时间重叠。然而,由于 GSPMD 允许用户在张量定义_之后_修改张量分片,我们需要一种优化来防止张量数据在宿主机和设备之间不必要的来回传输。我们引入了虚拟设备优化,这是一种技术,首先将张量数据放置在虚拟设备 SPMD:0 上,然后在所有分片决策最终确定后,再上传到物理设备。在 SPMD 模式下,每个张量数据都放置在虚拟设备 SPMD:0 上。虚拟设备作为 XLA 设备 XLA:0 暴露给用户,其实际分片位于物理设备上,例如 TPU:0、TPU:1 等。
混合网格¶
网格很好地抽象了物理设备网格的构建方式。用户可以使用逻辑网格以任何形状和顺序排列设备。然而,可以根据物理拓扑定义性能更好的网格,尤其是在涉及数据中心网络 (DCN) 跨切片连接时。HybridMesh 创建了一个网格,在这种多切片环境中提供了良好的开箱即用性能。它接受 ici_mesh_shape 和 dcn_mesh_shape,它们表示内部和外部网络的逻辑网格形状。
from torch_xla.distributed.spmd import HybridMesh
# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)
mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
在 TPU Pod 上运行 SPMD¶
如果你的网格和分区规范是基于设备数量而不是一些硬编码常量构建的,那么从单个 TPU 宿主机迁移到 TPU Pod 不需要修改代码。要在 TPU Pod 上运行 PyTorch/XLA 工作负载,请参考我们的 PJRT 指南中的Pod 部分。
XLAShardedTensor¶
xs.mark_sharding
是一个 inplace 操作,它会将分片注释附加到输入张量,但它也返回一个 XLAShardedTensor
Python 对象。
XLAShardedTensor
[RFC] 的主要用例是用分片规范注解原生的 torch.tensor
(在单个设备上)。注解立即发生,但张量的实际分片是延迟的,因为计算是惰性执行的,但输入张量会立即分片。一旦张量被注解并包装在 XLAShardedTensor
中,它可以作为 torch.Tensor
传递给现有的 PyTorch ops 和 nn.Module
层。这很重要,以确保相同的 PyTorch 层和张量 ops 可以与 XLAShardedTensor
一起使用。这意味着用户不需要重写现有的 ops 和模型代码来实现分片计算。具体来说,XLAShardedTensor
将满足以下要求
XLAShardedTensor
是torch.Tensor
的一个子类,可以直接与原生 torch ops 和module.layers
一起工作。我们使用__torch_dispatch__
将XLAShardedTensor
发送到 XLA 后端。PyTorch/XLA 检索附加的分片注解以跟踪图并调用 XLA SPMDPartitioner。在内部,
XLAShardedTensor
(及其 global_tensor 输入)由XLATensor
支持,后者具有一个特殊的数据结构,用于保存对分片设备数据的引用。惰性执行后的分片张量在宿主机上请求时(例如,打印 global_tensor 的值)可以被收集并具体化回宿主机作为 global_tensor。
本地分片的句柄在惰性执行后严格具体化。
XLAShardedTensor
暴露了 local_shards 以列表List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]
的形式返回可寻址设备上的本地分片。
目前也在努力将 XLAShardedTensor
集成到 DistributedTensor
API 中以支持 XLA 后端 [RFC]。
DTensor 集成¶
PyTorch 在 2.1 版本中原型发布了 DTensor。我们将 PyTorch/XLA SPMD 集成到 DTensor API RFC 中。我们有一个 distribute_tensor
的概念验证集成,它调用 mark_sharding
注解 API,使用 XLA 对张量及其计算进行分片。
import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor
# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
此功能为实验性功能,请持续关注后续版本中的更多更新、示例和教程。
torch.compile 的激活分片¶
在 2.3 版本中,PyTorch/XLA 添加了自定义 op dynamo_mark_sharding
,可用于在 torch.compile
区域中执行激活分片。这是我们正在努力的一部分,旨在使 torch.compile
+ GSPMD
成为使用 PyTorch/XLA 进行模型推理的推荐方式。此自定义 op 的使用示例
# Activation output sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = "('data', 'model')" # string version of axis_names
partition_spec = "('data', 'model')" # string version of partition spec
torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)
SPMD 调试工具¶
我们为 PyTorch/XLA SPMD 用户提供了 分片 放置 可视化 调试 工具
,支持单宿主机/多宿主机的 TPU/GPU/CPU:你可以使用 visualize_tensor_sharding
可视化分片张量,或者使用 visualize_sharding
可视化分片字符串。以下是在 TPU 单宿主机 (v4-8) 上使用 visualize_tensor_sharding
或 visualize_sharding
的两个代码示例
使用
visualize_tensor_sharding
的代码片段和可视化结果
import rich
# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))
# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)

使用
visualize_sharding
的代码片段和可视化结果
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)

你可以在 TPU/GPU/CPU 单宿主机上使用这些示例,并修改它们以在多宿主机上运行。你还可以修改它们以适应 tiled
、partial_replication
和 replicated
等分片样式。
自动分片¶
我们正在推出一个新的 PyTorch/XLA SPMD 功能,称为 自动分片
[RFC]。这是 r2.3
和 nightly
版本中的实验性功能,支持 XLA:TPU
和单个 TPUVM 宿主机。
PyTorch/XLA 自动分片可以通过以下方式之一启用
设置环境变量
XLA_AUTO_SPMD=1
在代码开头调用 SPMD API
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
调用
pytorch.distributed._tensor.distribute_module
并使用auto-policy
和xla
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
# Currently, model should be loaded to xla device via distribute_module.
model = MyModule() # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)
(可选)可以设置以下选项/环境变量来控制基于 XLA 的自动分片过程的行为
XLA_AUTO_USE_GROUP_SHARDING
:参数的组重新分片。默认设置。XLA_AUTO_SPMD_MESH
:用于自动分片的逻辑网格形状。例如,XLA_AUTO_SPMD_MESH=2,2
对应于一个包含 4 个全局设备的 2x2 网格。如果未设置,将使用默认设备网格形状num_devices,1
。