• 文档 >
  • PyTorch XLA 中的完全分片数据并行
快捷方式

PyTorch XLA 中的完全分片数据并行

PyTorch XLA 中的完全分片数据并行 (FSDP) 是一种用于在数据并行工作节点之间分片模块参数的实用工具。

用法示例

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

也可以单独分片每个层,并由外部包装器处理任何剩余参数。

注意:XlaFullyShardedDataParallel 类支持 https://arxiv.org/abs/1910.02054 中描述的 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态)。ZeRO-3 优化器应通过带有 reshard_after_forward=True 的嵌套 FSDP 实现。请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py 获取示例。* 对于无法放入单个 TPU 内存或主机 CPU 内存的大型模型,应在构建子模块时与内部 FSDP 包装交错进行。请参阅 FSDPViTModel 获取示例。提供了一个简单的包装器 checkpoint_module(基于 https://github.com/pytorch/xla/pull/3524 中的 torch_xla.utils.checkpoint.checkpoint),用于在给定的 nn.Module 实例上执行梯度检查点。请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py 获取示例。子模块自动包装:除了手动进行嵌套 FSDP 包装之外,还可以指定一个 auto_wrap_policy 参数,以自动使用内部 FSDP 包装子模块。torch_xla.distributed.fsdp.wrap 中的 size_based_auto_wrap_policy 是一个 auto_wrap_policy 可调用对象的示例,此策略会包装参数数量大于 100M 的层。torch_xla.distributed.fsdp.wrap 中的 transformer_auto_wrap_policy 是一个用于类似 transformer 模型架构的 auto_wrap_policy 可调用对象的示例。

例如,要自动使用内部 FSDP 包装所有 torch.nn.Conv2d 子模块,可以使用

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,还可以指定 auto_wrapper_callable 参数,为子模块使用自定义可调用包装器(默认包装器就是 XlaFullyShardedDataParallel 类本身)。例如,可以使用以下方式对每个自动包装的子模块应用梯度检查点(即激活检查点/再物化)。

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 在进行优化器步进时,直接调用 optimizer.step,而不要调用 xm.optimizer_step。后者会在各个 rank 之间归约梯度,这对于 FSDP 是不必要的(因为参数已经分片)。

  • 在训练期间保存模型和优化器检查点时,每个训练进程需要保存其自身的(分片)模型和优化器状态字典检查点(在 xm.save 中使用 master_only=False 并为每个 rank 设置不同的路径)。恢复训练时,需要加载相应 rank 的检查点。

  • 请同时保存 model.get_shard_metadata()model.state_dict(),并使用 consolidate_sharded_model_checkpoints 将分片的模型检查点合并成一个完整的模型状态字典。请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.py 获取示例。

ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xr.global_ordinal()}-of-{xr.world_size()}.pth'
xm.save(ckpt, ckpt_path, master_only=False)
  • 检查点合并脚本也可以按如下方式从命令行启动。

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /path/to/your_sharded_checkpoint_files \
  --ckpt_suffix "_rank-*-of-*.pth"

此类的实现很大程度上受到 https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.htmlfairscale.nn.FullyShardedDataParallel 结构的启发,并主要遵循其结构。与 fairscale.nn.FullyShardedDataParallel 的最大区别之一是,在 XLA 中我们没有显式的参数存储,因此这里我们采用不同的方法来为 ZeRO-3 释放完整的参数。

MNIST 和 ImageNet 上的训练脚本示例

安装

FSDP 可用于 PyTorch/XLA 1.12 版本及更高版本的每夜构建。请参阅 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南。

克隆 PyTorch/XLA 仓库

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上训练 MNIST

经过 2 个 epoch,准确率达到约 98.9%

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

此脚本会在最后自动测试检查点合并。您也可以通过以下方式手动合并分片检查点:

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet

经过 100 个 epoch,准确率达到约 75.9%;将 ImageNet-1k 下载到 /datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

您还可以添加 --use_gradient_checkpointing(需要与 --use_nested_fsdp--auto_wrap_policy 一起使用),以在残差块上应用梯度检查点。

在 TPU Pod 上训练脚本示例(具有 100 亿参数)

要训练无法放入单个 TPU 的大型模型,应在构建整个模型时应用自动包装或手动使用内部 FSDP 包装子模块,以实现 ZeRO-3 算法。

请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example 获取使用此 XLA FSDP PR 对 Vision Transformer (ViT) 模型进行分片训练的示例。

文档

查阅 PyTorch 全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获取问题解答

查看资源