• 文档 >
  • PyTorch/XLA SPMD 高级主题
快捷方式

PyTorch/XLA SPMD 高级主题

在本文档中,我们将介绍关于 GSPMD 的一些高级主题。在继续阅读本文档之前,请先阅读 SPMD 用户指南

分片感知的主机到设备数据加载

PyTorch/XLA SPMD 采用单设备程序,对其进行分片并并行执行。SPMD 执行需要使用原生 PyTorch DataLoader,它从主机同步传输数据到 XLA 设备。这会在每一步的输入数据传输期间阻塞训练。为了提高原生数据加载性能,当传递可选的 kwarg _input_sharding_ 时,我们使 PyTorch/XLA ParallelLoader 直接支持输入分片 (src)。

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
         train_loader,  # wraps PyTorch DataLoader
         device,
     # assume 4d input and we want to shard at the batch dimension.
         input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))

如果批次中每个元素的形状不同,也可以为它们指定不同的 input_sharding

# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor for shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
         train_loader,  # wraps PyTorch DataLoader
         device,
     # specify different sharding for each input of the batch.
         input_sharding={
          'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
          'y': xs.ShardingSpec(input_mesh, ('data', None))
        }
)

虚拟设备优化

PyTorch/XLA 通常在张量定义后,将张量数据从主机异步传输到设备。这是为了使数据传输与图追踪时间重叠。然而,由于 GSPMD 允许用户在张量定义_之后_修改张量分片,我们需要一种优化来防止不必要的张量数据在主机和设备之间来回传输。我们引入了虚拟设备优化,这是一种将张量数据首先放置在虚拟设备 SPMD:0 上的技术,然后在所有分片决策最终确定后上传到物理设备。SPMD 模式下的每个张量数据都放置在虚拟设备 SPMD:0 上。虚拟设备作为 XLA 设备 XLA:0 向用户公开,实际分片位于物理设备上,如 TPU:0、TPU:1 等。

混合网格

网格很好地抽象了物理设备网格的构建方式。用户可以使用逻辑网格以任何形状和顺序排列设备。然而,可以基于物理拓扑定义性能更高的网格,尤其是在涉及数据中心网络 (DCN) 跨切片连接时。HybridMesh 创建了一个网格,该网格为这种多切片环境提供了开箱即用的良好性能。它接受 ici_mesh_shape 和 dcn_mesh_shape,它们表示内部和外部网络的逻辑网格形状。

from torch_xla.distributed.spmd import HybridMesh

# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)

mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

在 TPU Pod 上运行 SPMD

如果根据设备数量而不是某些硬编码常量构建网格和分区规范,则从单 TPU 主机迁移到 TPU Pod 不需要代码更改。要在 TPU Pod 上运行 PyTorch/XLA 工作负载,请参阅我们的 PJRT 指南的 Pods 部分

XLAShardedTensor

xs.mark_sharding 是一个原地操作,它会将分片注释附加到输入张量,但它也会返回一个 XLAShardedTensor python 对象。

XLAShardedTensor 的主要用例 [RFC] 是使用分片规范注释原生 torch.tensor(在单个设备上)。注释立即发生,但张量的实际分片会延迟,因为计算是惰性执行的,除了输入张量会立即分片。一旦张量被注释并包装在 XLAShardedTensor 中,它就可以作为 torch.Tensor 传递给现有的 PyTorch 操作和 nn.Module 层。这对于确保相同的 PyTorch 层和张量操作可以与 XLAShardedTensor 堆叠在一起非常重要。这意味着用户无需为分片计算重写现有的操作和模型代码。也就是说,XLAShardedTensor 将满足以下要求

  • XLAShardedTensortorch.Tensor 子类,可直接与原生 torch 操作和 module.layers 一起使用。我们使用 __torch_dispatch__XLAShardedTensor 发送到 XLA 后端。PyTorch/XLA 检索附加的分片注释以追踪图,并调用 XLA SPMDPartitioner。

  • 在内部,XLAShardedTensor(及其 global_tensor 输入)由 XLATensor 支持,后者具有特殊的数据结构,用于保存对分片设备数据的引用。

  • 在惰性执行之后,分片张量可以在主机上请求时被收集并物化回主机,作为 global_tensor(例如,打印全局张量的值)。

  • 本地分片的句柄在惰性执行后严格物化。XLAShardedTensor 公开 local_shards,以将可寻址设备上的本地分片作为 List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)] 返回。

目前还在努力将 XLAShardedTensor 集成到 DistributedTensor API 中,以支持 XLA 后端 [RFC]。

DTensor 集成

PyTorch 在 2.1 中原型发布了 DTensor。我们正在将 PyTorch/XLA SPMD 集成到 DTensor API RFC 中。我们有一个 distribute_tensor 的概念验证集成,它调用 mark_sharding 注释 API 来分片张量及其使用 XLA 的计算

import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor

# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

此功能是实验性的,请继续关注即将发布的版本中的更多更新、示例和教程。

torch.compile 的激活分片

在 2.3 版本中,PyTorch/XLA 添加了自定义操作 dynamo_mark_sharding,该操作可用于在 torch.compile 区域中执行激活分片。这是我们正在进行的努力的一部分,旨在使 torch.compile + GSPMD 成为使用 PyTorch/XLA 进行模型推理的推荐方式。使用此自定义操作的示例

# Activation output sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = "('data', 'model')" # string version of axis_names
partition_spec = "('data', 'model')" # string version of partition spec
torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, partition_spec)

SPMD 调试工具

我们为 TPU/GPU/CPU 上使用单主机/多主机的 PyTorch/XLA SPMD 用户提供了一个 分片放置可视化调试工具:您可以使用 visualize_tensor_sharding 可视化分片张量,或者可以使用 visualize_sharding 可视化共享字符串。以下是关于 TPU 单主机 (v4-8) 使用 visualize_tensor_shardingvisualize_sharding 的两个代码示例

  • 使用的代码片段 visualize_tensor_sharding 和可视化结果

import rich

# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))

# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
visualize_tensor_sharding example on TPU v4-8(single-host)
  • 使用的代码片段 visualize_sharding 和可视化结果

from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
visualize_sharding example on TPU v4-8(single-host)

您可以在 TPU/GPU/CPU 单主机上使用这些示例,并修改它们以在多主机上运行。您可以修改它以使用分片样式 tiledpartial_replicationreplicated

自动分片

我们正在推出一项新的 PyTorch/XLA SPMD 功能,称为 自动分片RFC。这是 r2.3nightly 中的一项实验性功能,支持 XLA:TPU 和单 TPUVM 主机。

可以通过以下方式之一启用 PyTorch/XLA 自动分片

  • 设置环境变量 XLA_AUTO_SPMD=1

  • 在代码开头调用 SPMD API

import torch_xla.runtime as xr
xr.use_spmd(auto=True)
  • 使用 auto-policyxla 调用 pytorch.distributed._tensor.distribute_module

import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, model should be loaded to xla device via distribute_module.
model = MyModule()  # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)

(可选)可以设置以下选项/环境变量来控制基于 XLA 的自动分片传递的行为

  • XLA_AUTO_USE_GROUP_SHARDING:参数的组重新分片。默认设置。

  • XLA_AUTO_SPMD_MESH:用于自动分片的逻辑网格形状。例如,XLA_AUTO_SPMD_MESH=2,2 对应于具有 4 个全局设备的 2x2 网格。如果未设置,将使用 num_devices,1 的默认设备网格形状。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源