Distributed Data Parallel 入门¶
创建于:2019 年 4 月 23 日 | 最后更新:2024 年 10 月 30 日 | 最后验证:2024 年 11 月 05 日
作者: Shen Li
编辑: Joe Zhu, Chirag Pandya
注意
在 github 上查看和编辑本教程。
先决条件
DistributedDataParallel (DDP) 是 PyTorch 中一个强大的模块,允许您跨多台机器并行化您的模型,使其非常适合大规模深度学习应用程序。要使用 DDP,您需要生成多个进程,并为每个进程创建一个 DDP 实例。
但它是如何工作的呢?DDP 使用来自 torch.distributed 包的集体通信来同步所有进程之间的梯度和缓冲区。这意味着每个进程都将拥有模型的副本,但它们将协同工作来训练模型,就像它在单台机器上一样。
为了实现这一点,DDP 为模型中的每个参数注册一个 autograd 钩子。当运行反向传播时,此钩子会触发并在所有进程之间触发梯度同步。这确保了每个进程都具有相同的梯度,然后用于更新模型。
有关 DDP 如何工作以及如何有效使用它的更多信息,请务必查看 DDP 设计注释。借助 DDP,您可以比以往更快、更高效地训练模型!
推荐使用 DDP 的方法是为每个模型副本生成一个进程。模型副本可以跨多个设备。DDP 进程可以放置在同一台机器上或跨机器。请注意,GPU 设备不能在 DDP 进程之间共享(即,一个 DDP 进程对应一个 GPU)。
在本教程中,我们将从基本的 DDP 用例开始,然后演示更高级的用例,包括检查点模型以及将 DDP 与模型并行相结合。
注意
本教程中的代码在 8-GPU 服务器上运行,但它可以很容易地推广到其他环境。
DataParallel
和 DistributedDataParallel
之间的比较¶
在我们深入探讨之前,让我们澄清一下,尽管 DistributedDataParallel
增加了复杂性,但您为什么要考虑使用它而不是 DataParallel
首先,
DataParallel
是单进程、多线程的,但它仅在单台机器上工作。相比之下,DistributedDataParallel
是多进程的,并且支持单机和多机训练。由于线程之间的 GIL 争用、每次迭代复制模型以及 scattering 输入和 gathering 输出引入的额外开销,即使在单台机器上,DataParallel
通常也比DistributedDataParallel
慢。回顾 之前的教程,如果您的模型太大而无法容纳在单个 GPU 上,则必须使用模型并行将其拆分到多个 GPU 上。
DistributedDataParallel
与模型并行一起工作,而DataParallel
目前不支持。当 DDP 与模型并行结合使用时,每个 DDP 进程将使用模型并行,所有进程共同将使用数据并行。
基本用例¶
要创建 DDP 模块,您必须首先正确设置进程组。更多详细信息可以在 使用 PyTorch 编写分布式应用程序 中找到。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
# On Windows platform, the torch.distributed package only
# supports Gloo backend, FileStore and TcpStore.
# For FileStore, set init_method parameter in init_process_group
# to a local file. Example as follow:
# init_method="file:///f:/libtmp/some_file"
# dist.init_process_group(
# "gloo",
# rank=rank,
# init_method=init_method,
# world_size=world_size)
# For TcpStore, same way as on Linux.
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
现在,让我们创建一个玩具模块,用 DDP 包装它,并给它一些虚拟输入数据。请注意,由于 DDP 在 DDP 构造函数中将模型状态从 rank 0 进程广播到所有其他进程,因此您无需担心不同的 DDP 进程从不同的初始模型参数值开始。
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
print(f"Finished running basic DDP example on rank {rank}.")
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
如您所见,DDP 封装了较低级别的分布式通信细节,并提供了干净的 API,就像它是一个本地模型一样。梯度同步通信发生在反向传播期间,并与反向计算重叠。当 backward()
返回时,param.grad
已经包含同步的梯度张量。对于基本用例,DDP 只需要几行代码来设置进程组。当将 DDP 应用于更高级的用例时,一些注意事项需要谨慎。
处理速度偏差¶
在 DDP 中,构造函数、前向传播和反向传播是分布式同步点。不同的进程应启动相同数量的同步,并以相同的顺序到达这些同步点,并在大致相同的时间进入每个同步点。否则,快速进程可能会提前到达并在等待落后者时超时。因此,用户负责平衡跨进程的工作负载分配。有时,由于网络延迟、资源争用或不可预测的工作负载峰值等原因,处理速度偏差是不可避免的。为了避免在这些情况下超时,请确保在调用 init_process_group 时传递足够大的 timeout
值。
保存和加载检查点¶
通常使用 torch.save
和 torch.load
在训练期间检查点模块并从检查点恢复。有关更多详细信息,请参阅 保存和加载模型。使用 DDP 时,一种优化是在一个进程中保存模型,然后在所有进程上加载它,从而减少写入开销。这是可行的,因为所有进程都从相同的参数开始,并且梯度在反向传播中同步,因此优化器应保持将参数设置为相同的值。如果您使用此优化(即,在一个进程上保存但在所有进程上恢复),请确保在保存完成之前没有进程开始加载。此外,在加载模块时,您需要提供适当的 map_location
参数,以防止进程进入其他进程的设备。如果缺少 map_location
,torch.load
将首先将模块加载到 CPU,然后将每个参数复制到其保存的位置,这将导致同一台机器上的所有进程使用同一组设备。有关更高级的故障恢复和弹性支持,请参阅 TorchElastic。
def demo_checkpoint(rank, world_size):
print(f"Running DDP checkpoint example on rank {rank}.")
setup(rank, world_size)
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
# All processes should see same parameters as they all start from same
# random parameters and gradients are synchronized in backward passes.
# Therefore, saving it in one process is sufficient.
torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# configure map_location properly
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location, weights_only=True))
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
# Not necessary to use a dist.barrier() to guard the file deletion below
# as the AllReduce ops in the backward pass of DDP already served as
# a synchronization.
if rank == 0:
os.remove(CHECKPOINT_PATH)
cleanup()
print(f"Finished running DDP checkpoint example on rank {rank}.")
将 DDP 与模型并行相结合¶
DDP 也适用于多 GPU 模型。当使用大量数据训练大型模型时,DDP 包装多 GPU 模型特别有用。
class ToyMpModel(nn.Module):
def __init__(self, dev0, dev1):
super(ToyMpModel, self).__init__()
self.dev0 = dev0
self.dev1 = dev1
self.net1 = torch.nn.Linear(10, 10).to(dev0)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(10, 5).to(dev1)
def forward(self, x):
x = x.to(self.dev0)
x = self.relu(self.net1(x))
x = x.to(self.dev1)
return self.net2(x)
当将多 GPU 模型传递给 DDP 时,不得设置 device_ids
和 output_device
。输入和输出数据将由应用程序或模型 forward()
方法放置在适当的设备中。
def demo_model_parallel(rank, world_size):
print(f"Running DDP with model parallel example on rank {rank}.")
setup(rank, world_size)
# setup mp_model and devices for this process
dev0 = rank * 2
dev1 = rank * 2 + 1
mp_model = ToyMpModel(dev0, dev1)
ddp_mp_model = DDP(mp_model)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
optimizer.zero_grad()
# outputs will be on dev1
outputs = ddp_mp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(dev1)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
print(f"Finished running DDP with model parallel example on rank {rank}.")
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_basic, world_size)
run_demo(demo_checkpoint, world_size)
world_size = n_gpus//2
run_demo(demo_model_parallel, world_size)
使用 torch.distributed.run/torchrun 初始化 DDP¶
我们可以利用 PyTorch Elastic 来简化 DDP 代码并更轻松地初始化作业。让我们仍然使用 Toymodel 示例并创建一个名为 elastic_ddp.py
的文件。
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic():
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")
rank = dist.get_rank()
print(f"Start running basic DDP example on rank {rank}.")
# create model and move it to GPU with id rank
device_id = rank % torch.cuda.device_count()
model = ToyModel().to(device_id)
ddp_model = DDP(model, device_ids=[device_id])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(device_id)
loss_fn(outputs, labels).backward()
optimizer.step()
dist.destroy_process_group()
print(f"Finished running basic DDP example on rank {rank}.")
if __name__ == "__main__":
demo_basic()
然后,可以在所有节点上运行 torch elastic/torchrun 命令,以初始化上面创建的 DDP 作业
torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py
在上面的示例中,我们在两台主机上运行 DDP 脚本,并且我们在每台主机上运行 8 个进程。也就是说,我们在 16 个 GPU 上运行此作业。请注意,$MASTER_ADDR
在所有节点上必须相同。
在这里,torchrun
将启动 8 个进程,并在其启动的节点上的每个进程上调用 elastic_ddp.py
,但用户还需要应用集群管理工具(如 slurm)来实际在 2 个节点上运行此命令。
例如,在启用 SLURM 的集群上,我们可以编写一个脚本来运行上面的命令并将 MASTER_ADDR
设置为
export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
然后我们可以使用 SLURM 命令运行此脚本:srun --nodes=2 ./torchrun_script.sh
。
这只是一个示例;您可以选择自己的集群调度工具来启动 torchrun
作业。
有关 Elastic run 的更多信息,请参阅快速入门文档。