• 文档 >
  • 使用 SPMD 的全分片数据并行 (Fully Sharded Data Parallel)
快捷方式

使用 SPMD 的全分片数据并行 (Fully Sharded Data Parallel)

通过 SPMD 实现的全分片数据并行(FSDPv2)是一种将著名的 FSDP 算法重新表达为 SPMD 的实用工具。是一项实验性功能,旨在为用户提供熟悉的接口,以便享受 SPMD 带来的所有优势。设计文档在此处

在继续之前,请查阅SPMD 用户指南。您还可以在此处找到一个最小的可运行示例。

示例用法

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))

# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))

# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以单独对各个层进行分片,并由外部包装器处理剩余的参数。以下是一个自动包装每个 DecoderLayer 的示例。

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Apply FSDP sharding on each DecoderLayer layer.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        decoder_only_model.DecoderLayer
    },
)
model = FSDPv2(
    model, mesh=mesh, auto_wrap_policy=auto_wrap_policy)

分片输出

为了确保 XLA 编译器正确实现 FSDP 算法,我们需要对权重和激活进行分片。这意味着对前向传播方法的输出进行分片。由于前向函数的输出可能不同,在您的模块输出不属于以下任何一种情况时,我们提供了 shard_output 来对激活进行分片:1. 单个张量 2. 第 0 个元素是激活的张量元组。

示例用法

def shard_output(output, mesh):
    xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))

model = FSDPv2(my_module, mesh, shard_output)

梯度检查点 (Gradient checkpointing)

当前,梯度检查点需要在 FSDP 包装器之前应用于模块。否则,递归进入子模块将导致无限循环。我们将在未来的版本中解决此问题。

示例用法

from torch_xla.distributed.fsdp import checkpoint_module

model = FSDPv2(checkpoint_module(my_module), mesh)

HuggingFace Llama 2 示例

我们有一个 HF Llama 2 的分支,在此处演示了一种潜在的集成。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源