• 教程 >
  • 分布式数据并行入门
快捷方式

分布式数据并行入门

创建日期:2019 年 4 月 23 日 | 最后更新:2024 年 10 月 30 日 | 最后验证:2024 年 11 月 5 日

作者: Shen Li

编辑: Joe Zhu, Chirag Pandya

注意

编辑github 中查看和编辑本教程。

先决条件

DistributedDataParallel (DDP) 是 PyTorch 中一个强大的模块,它允许你在多台机器上并行化模型,使其非常适合大规模深度学习应用。要使用 DDP,你需要启动多个进程,并为每个进程创建一个 DDP 实例。

但它是如何工作的呢?DDP 利用 torch.distributed 包中的集体通信来同步所有进程的梯度和缓冲区。这意味着每个进程将拥有自己模型的副本,但它们将协同工作,就像在单台机器上训练模型一样。

为了实现这一点,DDP 为模型中的每个参数注册了一个自动微分钩子。当运行反向传播时,这个钩子会触发,并在所有进程中同步梯度。这确保了每个进程拥有相同的梯度,然后使用这些梯度来更新模型。

有关 DDP 工作原理以及如何有效使用它的更多信息,请务必查看 DDP 设计说明。使用 DDP,你可以比以往更快、更高效地训练模型!

推荐使用 DDP 的方式是为每个模型副本启动一个进程。模型副本可以跨越多个设备。DDP 进程可以位于同一台机器上,也可以跨越多台机器。请注意,GPU 设备不能在 DDP 进程之间共享(即一个 GPU 对应一个 DDP 进程)。

在本教程中,我们将从 DDP 的基本用例开始,然后演示更高级的用例,包括模型检查点和将 DDP 与模型并行结合使用。

注意

本教程中的代码在具有 8 个 GPU 的服务器上运行,但可以轻松推广到其他环境。

DataParallelDistributedDataParallel 的比较

在深入探讨之前,让我们阐明为什么尽管 DistributedDataParallel 增加了复杂性,你仍会考虑使用它而不是 DataParallel

  • 首先,DataParallel 是单进程多线程的,但它只能在单台机器上工作。相比之下,DistributedDataParallel 是多进程的,支持单机和多机训练。由于线程间的 GIL 竞争、每迭代复制模型以及分散输入和收集输出引入的额外开销,即使在单台机器上,DataParallel 通常也比 DistributedDataParallel 慢。

  • 回顾先前的教程,如果你的模型太大无法放入单个 GPU,则必须使用 模型并行 将其分割到多个 GPU 上。DistributedDataParallel 可以与 模型并行 配合使用,而 DataParallel 目前不行。当 DDP 与模型并行结合使用时,每个 DDP 进程将使用模型并行,而所有进程将共同使用数据并行。

基本用例

要创建 DDP 模块,必须首先正确设置进程组。更多详细信息请参阅使用 PyTorch 编写分布式应用程序

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

# On Windows platform, the torch.distributed package only
# supports Gloo backend, FileStore and TcpStore.
# For FileStore, set init_method parameter in init_process_group
# to a local file. Example as follow:
# init_method="file:///f:/libtmp/some_file"
# dist.init_process_group(
#    "gloo",
#    rank=rank,
#    init_method=init_method,
#    world_size=world_size)
# For TcpStore, same way as on Linux.

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

现在,让我们创建一个玩具模块,用 DDP 包装它,并输入一些模拟数据。请注意,由于 DDP 在构造函数中将模型状态从 rank 0 进程广播到所有其他进程,因此你不必担心不同的 DDP 进程从不同的初始模型参数值开始。

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()
    print(f"Finished running basic DDP example on rank {rank}.")


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

如你所见,DDP 封装了底层的分布式通信细节,并提供了一个简洁的 API,就像它是本地模型一样。梯度同步通信发生在反向传播期间,并与反向计算重叠。当 backward() 返回时,param.grad 中已经包含同步后的梯度张量。对于基本用例,设置进程组只需要几行额外的代码。将 DDP 应用于更高级的用例时,需要注意一些事项。

处理速度倾斜

在 DDP 中,构造函数、前向传播和反向传播是分布式同步点。不同的进程应该启动相同数量的同步,并按相同的顺序到达这些同步点,并在大致相同的时间进入每个同步点。否则,快速进程可能会提前到达并在等待慢进程时超时。因此,用户有责任平衡进程间的工作负载分布。有时,由于网络延迟、资源竞争或不可预测的工作负载峰值等原因,处理速度倾斜是不可避免的。为避免在这些情况下发生超时,请确保在调用 init_process_group 时传递足够大的 timeout 值。

保存和加载检查点

在训练过程中使用 torch.savetorch.load 进行模块检查点并从检查点恢复是很常见的。更多详细信息请参阅保存和加载模型。使用 DDP 时,一个优化是只在一个进程中保存模型,然后在所有进程上加载,以减少写入开销。这之所以有效,是因为所有进程从相同的参数开始,并且梯度在反向传播中是同步的,因此优化器应该保持将参数设置为相同的值。如果使用此优化(即在一个进程上保存但在所有进程上恢复),请确保在保存完成之前没有进程开始加载。此外,加载模块时,需要提供适当的 map_location 参数,以防止进程踏入其他进程的设备。如果缺少 map_locationtorch.load 将首先将模块加载到 CPU,然后将每个参数复制到其保存的位置,这将导致同一台机器上的所有进程使用同一组设备。有关更高级的故障恢复和弹性支持,请参阅TorchElastic

def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])


    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location, weights_only=True))

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)

    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
    print(f"Finished running DDP checkpoint example on rank {rank}.")

将 DDP 与模型并行结合使用

DDP 也适用于多 GPU 模型。当使用大量数据训练大型模型时,DDP 包装多 GPU 模型尤其有用。

class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)

将多 GPU 模型传递给 DDP 时,不能设置 device_idsoutput_device。输入和输出数据将由应用程序或模型的 forward() 方法放置在适当的设备上。

def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # setup mp_model and devices for this process
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()
    print(f"Finished running DDP with model parallel example on rank {rank}.")


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)
    run_demo(demo_checkpoint, world_size)
    world_size = n_gpus//2
    run_demo(demo_model_parallel, world_size)

使用 torch.distributed.run/torchrun 初始化 DDP

我们可以利用 PyTorch Elastic 来简化 DDP 代码并更容易地初始化任务。让我们继续使用 Toymodel 示例,并创建一个名为 elastic_ddp.py 的文件。

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")
    # create model and move it to GPU with id rank
    device_id = rank % torch.cuda.device_count()
    model = ToyModel().to(device_id)
    ddp_model = DDP(model, device_ids=[device_id])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    dist.destroy_process_group()
    print(f"Finished running basic DDP example on rank {rank}.")

if __name__ == "__main__":
    demo_basic()

然后可以在所有节点上运行 torch elastic/torchrun 命令来初始化上面创建的 DDP 任务

torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py

在上面的示例中,我们在两台主机上运行 DDP 脚本,并在每台主机上运行 8 个进程。也就是说,我们在 16 个 GPU 上运行这个任务。请注意,$MASTER_ADDR 在所有节点上必须相同。

这里的 torchrun 将启动 8 个进程,并在启动它的节点上的每个进程上调用 elastic_ddp.py,但用户还需要使用像 slurm 这样的集群管理工具来实际在 2 个节点上运行此命令。

例如,在启用 SLURM 的集群上,我们可以编写一个脚本来运行上面的命令并设置 MASTER_ADDR 如下

export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)

然后我们就可以使用 SLURM 命令运行这个脚本:srun --nodes=2 ./torchrun_script.sh

这只是一个例子;你可以选择自己的集群调度工具来启动 torchrun 任务。

有关弹性运行的更多信息,请参阅快速入门文档

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答问题

查看资源