• 文档 >
  • 如何进行 DistributedDataParallel(DDP)
快捷方式

如何进行 DistributedDataParallel(DDP)

本文档展示了如何在 xla 中使用 torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生 xla 数据并行方法的区别。您可以在此处找到一个最小可运行示例。

背景 / 动机

长期以来,客户一直要求能够在 xla 中使用 PyTorch 的 DistributedDataParallel API。我们在此将其作为实验性功能启用。

如何使用 DistributedDataParallel

对于那些从 PyTorch eager 模式切换到 XLA 的用户,以下是将您的 eager DDP 模型转换为 XLA 模型所需做的所有更改。我们假设您已经知道如何在单个设备上使用 XLA。

  1. 导入 xla 特定的分布式包

    import torch_xla
    import torch_xla.runtime as xr
    import torch_xla.distributed.xla_backend
    
  2. 初始化 xla 进程组,类似于其他进程组,如 nccl 和 gloo。

    dist.init_process_group("xla", rank=rank, world_size=world_size)
    
  3. 如果您需要,可以使用 xla 特定的 API 获取 rank 和 world_size。

    new_rank = xr.global_ordinal()
    world_size = xr.world_size()
    
  4. 使用 DDP 包装模型。

    ddp_model = DDP(model, gradient_as_bucket_view=True)
    
  5. 最后使用 xla 特定的启动器启动您的模型。

    torch_xla.launch(demo_fn)
    

这里我们将所有内容放在一起(该示例实际上取自DDP 教程)。您的编码方式与 eager 体验非常相似。只需在单个设备上进行 xla 特定的调整,再加上对脚本的上述五个更改。

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

# additional imports for xla
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend

def setup(rank, world_size):
    os.environ['PJRT_DEVICE'] = 'TPU'

    # initialize the xla process group
    dist.init_process_group("xla", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 1000000)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1000000, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank):
    # xla specific APIs to get rank, world_size.
    new_rank = xr.global_ordinal()
    assert new_rank == rank
    world_size = xr.world_size()

    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to XLA device
    device = xm.xla_device()
    model = ToyModel().to(device)
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    # xla specific API to execute the graph
    xm.mark_step()

    cleanup()


def run_demo(demo_fn):
    # xla specific launcher
    torch_xla.launch(demo_fn)

if __name__ == "__main__":
    run_demo(demo_basic)

基准测试

使用虚假数据的 Resnet50

以下结果是通过在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 的命令收集的

python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

统计指标是通过使用此pull request中的脚本生成的。速率单位为每秒图像数。

类型 平均值 中位数 90% 标准差 CV
xm.optimizer_step 418.54 419.22 430.40 9.76 0.02
DDP 395.97 395.54 407.13 7.60 0.02

我们的分布式数据并行原生方法和 DistributedDataParallel 包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。考虑到 DDP 包装器在跟踪 DDP 运行时引入了额外的开销,这个结果看起来是合理的。

使用虚假数据的 MNIST

以下结果是通过在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 的命令 python test/test_train_mp_mnist.py --fake_data 收集的。统计指标是通过使用此pull request中的脚本生成的。速率单位为每秒图像数。

类型 平均值 中位数 90% 标准差 CV
xm.optimizer_step 17864.19 20108.96 24351.74 5866.83 0.33
DDP 10701.39 11770.00 14313.78 3102.92 0.29

我们的分布式数据并行原生方法和 DistributedDataParallel 包装器之间的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。由于数据集较小,并且前几轮受到数据加载的严重影响,因此我们在此处比较了 90%。考虑到模型很小,这种减速是巨大的,但也是有道理的。额外的 DDP 运行时跟踪开销很难摊销。

使用真实数据的 MNIST

以下结果是通过在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 的命令 n 收集的

python test/test_train_mp_mnist.py --logdir mnist/ o.

我们可以观察到,即使 DDP 包装器最终仍实现了 97.48% 的高准确率,但它的收敛速度比原生 XLA 方法慢。(原生方法实现了 99%。)

免责声明

此功能仍处于实验阶段,并且正在积极开发中。请谨慎使用,并随时向xla github 仓库提交任何错误报告。对于那些对原生 xla 数据并行方法感兴趣的人,这里是教程

以下是一些正在调查的已知问题:* 与 torch.utils.data.DataLoader 一起使用时存在一些问题。test_train_mp_mnist.py 使用真实数据时会在退出前崩溃。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源