• 文档 >
  • 如何使用 DistributedDataParallel
快捷方式

如何使用 DistributedDataParallel

本文档展示了如何在 xla 中使用 torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生 xla 数据并行方法的区别。

背景/动机

客户长期以来一直要求能够将 PyTorch 的 DistributedDataParallel API 与 xla 一起使用。在这里,我们将其作为一项实验性功能启用。

如何使用 DistributedDataParallel

对于那些从 PyTorch Eager 模式切换到 XLA 的用户,以下列出了将您的 Eager DDP 模型转换为 XLA 模型所需进行的所有更改。我们假设您已经知道如何在单个设备上使用 XLA 使用 XLA

  1. 导入特定于 xla 的分布式包

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
  1. 初始化类似于其他进程组(如 nccl 和 gloo)的 xla 进程组。

dist.init_process_group("xla", rank=rank, world_size=world_size)
  1. 如果需要,使用特定于 xla 的 API 获取 rank 和 world_size。

new_rank = xm.get_ordinal()
world_size = xm.xrt_world_size()
  1. gradient_as_bucket_view=True 传递给 DDP 包装器。

ddp_model = DDP(model, gradient_as_bucket_view=True)
  1. 最后使用特定于 xla 的启动器启动您的模型。

xmp.spawn(demo_fn)

在这里,我们将所有内容整合在一起(此示例实际上取自 DDP 教程)。您的编码方式与 Eager 体验非常相似。只是在单个设备上增加了特定于 xla 的调整,以及对您的脚本进行上述五项更改。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xm.get_ordinal()
    assert new_rank == rank
    world_size = xm.xrt_world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    # currently, graident_as_bucket_view is needed to make DDP work for xla
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    xmp.spawn(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

基准测试

使用伪数据的 Resnet50

以下结果是在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 通过以下命令收集的:python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1。统计指标是使用此 pull request 中的脚本生成的。速率单位为每秒图像数。

类型 平均值 中位数 第 90 个百分位数 标准差 变异系数
xm.optimizer_step 418.54 419.22 430.40 9.76 0.02
DDP 395.97 395.54 407.13 7.60 0.02

我们的原生分布式数据并行方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。考虑到 DDP 包装器在跟踪 DDP 运行时引入了额外的开销,因此此结果似乎合理。

使用伪数据的 MNIST

以下结果是在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 通过以下命令收集的:python test/test_train_mp_mnist.py --fake_data。统计指标是使用此 pull request 中的脚本生成的。速率单位为每秒图像数。

类型 平均值 中位数 第 90 个百分位数 标准差 变异系数
xm.optimizer_step 17864.19 20108.96 24351.74 5866.83 0.33
DDP 10701.39 11770.00 14313.78 3102.92 0.29

我们的原生分布式数据并行方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。由于数据集较小,并且前几轮严重受数据加载的影响,因此我们在此处比较第 90 个百分位数。这种减速幅度很大,但考虑到模型很小,这是有道理的。额外的 DDP 运行时跟踪开销难以摊销。

使用真实数据的 MNIST

以下结果是在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 通过以下命令收集的:python test/test_train_mp_mnist.py --logdir mnist/

learning_curves

我们可以观察到,尽管 DDP 包装器最终仍实现了 97.48% 的高准确率,但其收敛速度比原生 XLA 方法慢。(原生方法实现了 99% 的准确率。)

免责声明

此功能仍处于实验阶段,并且正在积极开发中。请谨慎使用,并随时将任何错误报告给 xla github 仓库。对于那些对原生 xla 数据并行方法感兴趣的用户,这里有一个 教程

以下是一些正在调查中的已知问题

  • 需要强制执行 gradient_as_bucket_view=True

  • torch.utils.data.DataLoader 一起使用时存在一些问题。​​test_train_mp_mnist.py 在使用真实数据时会在退出前崩溃。

PyTorch XLA 中的全分片数据并行 (FSDP)

PyTorch XLA 中的全分片数据并行 (FSDP) 是一种用于跨数据并行工作器分片模块参数的实用程序。

示例用法

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()

还可以分别分片各个层,并让外部包装器处理任何剩余的参数。

注意

  • XlaFullyShardedDataParallel 类支持 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态),如 https://arxiv.org/abs/1910.02054 中所述。

    • ZeRO-3 优化器应通过嵌套 FSDP 实现,并使用 reshard_after_forward=True。有关示例,请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py

    • 对于无法容纳在单个 TPU 内存或主机 CPU 内存中的大型模型,应将子模块构造与内部 FSDP 包装交错进行。有关示例,请参阅 ``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_。

  • 提供了一个简单的包装器 checkpoint_module(基于 https://github.com/pytorch/xla/pull/3524 中的 torch_xla.utils.checkpoint.checkpoint),以便对给定的 nn.Module 实例执行 梯度检查点。有关示例,请参阅 test/test_train_mp_mnist_fsdp_with_ckpt.pytest/test_train_mp_imagenet_fsdp.py

  • 自动包装子模块:无需手动嵌套 FSDP 包装,还可以指定 auto_wrap_policy 参数来自动使用内部 FSDP 包装子模块。size_based_auto_wrap_policy(位于 torch_xla.distributed.fsdp.wrap 中)是 auto_wrap_policy 可调用的一个示例,此策略包装参数数量大于 100M 的层。transformer_auto_wrap_policy(位于 torch_xla.distributed.fsdp.wrap 中)是用于类似 Transformer 模型架构的 auto_wrap_policy 可调用的一个示例。

例如,要自动使用内部 FSDP 包装所有 torch.nn.Conv2d 子模块,可以使用

from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})

此外,还可以指定 auto_wrapper_callable 参数来为子模块使用自定义的可调用包装器(默认包装器只是 XlaFullyShardedDataParallel 类本身)。例如,可以使用以下方法将梯度检查点(即激活检查点/重新计算)应用于每个自动包装的子模块。

from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
  • 在执行优化器步骤时,直接调用 optimizer.step,不要调用 xm.optimizer_step。后者会跨进程减少梯度,这对于 FSDP(参数已经分片)来说是不必要的。

  • 在训练过程中保存模型和优化器检查点时,每个训练进程都需要保存其自己的(分片)模型和优化器状态字典的检查点(使用 master_only=False,并在 xm.save 中为每个进程设置不同的路径)。恢复时,需要加载对应进程的检查点。

  • 请同时保存 model.get_shard_metadata()model.state_dict(),如下所示,并使用 consolidate_sharded_model_checkpoints 将分片模型检查点拼接成完整的模型状态字典。请参考 test/test_train_mp_mnist_fsdp_with_ckpt.py 获取示例。 .. code-block:: python3

    ckpt = {

    ‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),

    } ckpt_path = f’/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)

  • 检查点合并脚本也可以从命令行启动,如下所示。 .. code-block:: bash

    # 通过命令行工具合并保存的检查点 python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”

此类的实现很大程度上受到 fairscale.nn.FullyShardedDataParallel 的启发,并且在很大程度上遵循其结构,该类可在 https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 中找到。与 fairscale.nn.FullyShardedDataParallel 最大的区别之一在于,在 XLA 中我们没有显式参数存储,因此在这里我们采用不同的方法来释放 ZeRO-3 的完整参数。


MNIST 和 ImageNet 上的示例训练脚本

安装

FSDP 可用于 PyTorch/XLA 1.12 版本及更新的 nightly 版本。有关安装指南,请参阅 https://github.com/pytorch/xla#-available-images-and-wheels

克隆 PyTorch/XLA 仓库

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/

在 v3-8 TPU 上训练 MNIST

它在 2 个 epoch 内获得了大约 98.9 的准确率

python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing

此脚本在最后自动测试检查点合并。您也可以通过以下方式手动合并分片检查点

# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"

在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet

它在 100 个 epoch 内获得了大约 75.9 的准确率;将 ImageNet-1k 下载到 /datasets/imagenet-1k

python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --use_nested_fsdp

您还可以添加 --use_gradient_checkpointing(需要与 --use_nested_fsdp--auto_wrap_policy 结合使用)以在残差块上应用梯度检查点。


在 TPU Pod 上的示例训练脚本(具有 100 亿个参数)

要训练无法容纳在单个 TPU 中的大型模型,应在构建整个模型时应用自动包装或手动包装具有内部 FSDP 的子模块以实现 ZeRO-3 算法。

有关使用此 XLA FSDP PR 进行视觉转换器 (ViT) 模型分片训练的示例,请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源