快捷方式

DistributedDataParallel

class torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, init_sync=True, process_group=None, bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False, delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision=None, device_mesh=None)[source][source]

在模块级别实现基于 torch.distributed 的分布式数据并行。

此容器通过跨每个模型副本同步梯度来提供数据并行性。同步设备由输入的 process_group 指定,默认情况下是整个世界。请注意,DistributedDataParallel 不会分块或以其他方式跨参与的 GPU 分片输入;用户负责定义如何做到这一点,例如通过使用 DistributedSampler

另请参阅:基础知识使用 nn.parallel.DistributedDataParallel 代替多进程或 nn.DataParalleltorch.nn.DataParallel 中相同的输入限制也适用。

创建此类需要已经初始化 torch.distributed,通过调用 torch.distributed.init_process_group()

对于单节点多 GPU 数据并行训练,DistributedDataParallel 经验证比 torch.nn.DataParallel 快得多。

要在具有 N 个 GPU 的主机上使用 DistributedDataParallel,应启动 N 个进程,确保每个进程专门在从 0 到 N-1 的单个 GPU 上工作。这可以通过为每个进程设置 CUDA_VISIBLE_DEVICES 或调用以下方法来完成

>>> torch.cuda.set_device(i)

其中 i 的范围是 0 到 N-1。在每个进程中,您应参照以下内容来构建此模块

>>> torch.distributed.init_process_group(
>>>     backend='nccl', world_size=N, init_method='...'
>>> )
>>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)

为了在每个节点上启动多个进程,可以使用 torch.distributed.launchtorch.multiprocessing.spawn

注意

请参阅 PyTorch 分布式概述,了解与分布式训练相关的所有特性的简要介绍。

注意

DistributedDataParallel 可以与 torch.distributed.optim.ZeroRedundancyOptimizer 结合使用,以减少每个 rank 优化器状态的内存占用。请参阅 ZeroRedundancyOptimizer 代码示例 获取更多详细信息。

注意

使用 GPU 时,nccl 后端是目前最快且强烈推荐的后端。这适用于单节点和多节点分布式训练。

注意

此模块还支持混合精度分布式训练。这意味着您的模型可以具有不同类型的参数,例如 fp16fp32 的混合类型,对这些混合类型参数的梯度归约将正常工作。

注意

如果在某个进程上使用 torch.save 检查点模块,并在其他一些进程上使用 torch.load 恢复它,请确保为每个进程正确配置了 map_location。如果没有 map_locationtorch.load 会将模块恢复到保存时的设备。

注意

当模型在 M 个节点上以 batch=N 进行训练时,如果损失在批次的实例之间求和(而不是像通常那样求平均值)(因为不同节点之间的梯度是平均的),则与在单个节点上以 batch=M*N 训练相同模型相比,梯度将小 M 倍。当您想要获得与本地训练对应的数学等效训练过程时,应考虑这一点。但在大多数情况下,您可以将 DistributedDataParallel 包装的模型、DataParallel 包装的模型和单个 GPU 上的普通模型视为相同(例如,对于等效批量大小使用相同的学习率)。

注意

参数从不在进程之间广播。该模块对梯度执行 all-reduce 步骤,并假定优化器将在所有进程中以相同方式修改它们。缓冲区(例如 BatchNorm 统计信息)在每次迭代中从 rank 0 进程中的模块广播到系统中的所有其他副本。

注意

如果将 DistributedDataParallel 与 分布式 RPC 框架 结合使用,应始终使用 torch.distributed.autograd.backward() 计算梯度,并使用 torch.distributed.optim.DistributedOptimizer 优化参数。

示例

>>> import torch.distributed.autograd as dist_autograd
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> import torch
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import RRef
>>>
>>> t1 = torch.rand((3, 3), requires_grad=True)
>>> t2 = torch.rand((3, 3), requires_grad=True)
>>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
>>> ddp_model = DDP(my_model)
>>>
>>> # Setup optimizer
>>> optimizer_params = [rref]
>>> for param in ddp_model.parameters():
>>>     optimizer_params.append(RRef(param))
>>>
>>> dist_optim = DistributedOptimizer(
>>>     optim.SGD,
>>>     optimizer_params,
>>>     lr=0.05,
>>> )
>>>
>>> with dist_autograd.context() as context_id:
>>>     pred = ddp_model(rref.to_here())
>>>     loss = loss_func(pred, target)
>>>     dist_autograd.backward(context_id, [loss])
>>>     dist_optim.step(context_id)

注意

DistributedDataParallel 目前对使用 torch.utils.checkpoint() 进行梯度检查点提供有限支持。如果使用 use_reentrant=False(推荐)完成检查点,DDP 将按预期工作,没有任何限制。但是,如果使用 use_reentrant=True(默认值)完成检查点,当模型中没有未使用的参数并且每个层最多检查点一次时,DDP 将按预期工作(请确保您没有将 find_unused_parameters=True 传递给 DDP)。我们目前不支持层多次检查点的情况,或检查点模型中存在未使用参数的情况。

注意

为了让非 DDP 模型加载 DDP 模型的状态字典,需要在加载之前应用 consume_prefix_in_state_dict_if_present() 以剥离 DDP 状态字典中的前缀“module.”。

警告

构造函数、forward 方法以及输出(或此模块输出的函数)的求导是分布式同步点。如果不同进程可能正在执行不同的代码,请务必考虑这一点。

警告

此模块假定在创建时所有参数都已在模型中注册。之后不应添加或删除任何参数。缓冲区也一样。

警告

此模块假定每个分布式进程的模型中注册的所有参数顺序相同。模块本身将按照模型注册参数的相反顺序执行梯度 allreduce。换句话说,用户有责任确保每个分布式进程具有完全相同的模型,从而具有完全相同的参数注册顺序。

警告

此模块允许使用非行主连续步长的参数。例如,您的模型可能包含一些 torch.memory_formattorch.contiguous_format 的参数,以及其他格式为 torch.channels_last 的参数。但是,不同进程中对应的参数必须具有相同的步长。

警告

此模块不适用于 torch.autograd.grad() (即,仅当梯度累积在参数的 .grad 属性中时才有效)。

警告

如果您打算将此模块与 nccl 后端或 gloo 后端(使用 Infiniband)以及使用多个 worker 的 DataLoader 一起使用,请将多进程启动方法更改为 forkserver(仅限 Python 3)或 spawn。遗憾的是,Gloo(使用 Infiniband)和 NCCL2 不是 fork 安全的,如果不更改此设置,可能会遇到死锁。

警告

使用 DistributedDataParallel 包装模型后,不应尝试更改模型的参数。因为,当使用 DistributedDataParallel 包装模型时,DistributedDataParallel 的构造函数将在构建时在模型本身的所有参数上注册额外的梯度归约函数。如果之后更改模型的参数,梯度归约函数将不再与正确的参数集匹配。

警告

DistributedDataParallel分布式 RPC 框架 结合使用是实验性的,可能会发生变化。

参数
  • module (Module) – 要并行的模块

  • device_ids (list of int or torch.device) –

    CUDA 设备。1) 对于单设备模块,device_ids 可以只包含一个设备 ID,表示此进程对应的输入模块所在的唯一 CUDA 设备。此外,device_ids 也可以是 None。2) 对于多设备模块和 CPU 模块,device_ids 必须是 None

    在两种情况下,如果 device_idsNone,则前向传播的输入数据和实际模块都必须放置在正确的设备上。(默认值:None

  • output_device (int or torch.device) – 单设备 CUDA 模块的输出设备位置。对于多设备模块和 CPU 模块,它必须为 None,并且模块本身决定输出位置。(默认值:单设备模块为 device_ids[0]

  • broadcast_buffers (bool) – 在 forward 函数开始时启用模块缓冲区同步(广播)的标志。(默认值:True

  • init_sync (bool) – 初始化期间是否同步以验证参数形状并广播参数和缓冲区。警告:如果将其设置为 False,则用户需要自行确保所有 rank 上的权重相同。(默认值:True

  • process_group – 用于分布式数据 all-reduction 的进程组。如果为 None,则将使用默认进程组,该进程组由 torch.distributed.init_process_group() 创建。(默认值:None

  • bucket_cap_mbDistributedDataParallel 会将参数分桶到多个桶中,以便每个桶的梯度归约可以潜在地与反向计算重叠。bucket_cap_mb 控制桶的大小,单位为 MiB。如果为 None,将使用默认大小 25 MiB。(默认值:None

  • find_unused_parameters (bool) – 遍历从包装模块的 forward 函数返回值中包含的所有张量开始的自动梯度图。作为此图一部分而未接收到梯度的参数会预先标记为准备好进行归约。此外,可能已在包装模块的 forward 函数中使用但未参与损失计算因此也不会接收到梯度的参数也会预先标记为准备好进行归约。(默认值:False

  • check_reduction – 此参数已弃用。

  • gradient_as_bucket_view (bool) – 当设置为 True 时,梯度将是指向 allreduce 通信桶不同偏移量的视图。这可以减少峰值内存使用,节省的内存大小将等于总梯度大小。此外,它避免了在梯度和 allreduce 通信桶之间进行复制的开销。当梯度是视图时,不能对梯度调用 detach_()。如果遇到此类错误,请参考 torch/optim/optimizer.py 中的 zero_grad() 函数来解决。请注意,梯度在第一次迭代后将是视图,因此峰值内存节省应在第一次迭代后检查。

  • static_graph (bool) –

    当设置为 True 时,DDP 知道训练图是静态的。静态图意味着 1) 整个训练循环中,已使用和未使用参数的集合不会改变;在这种情况下,用户是否设置 find_unused_parameters = True 并不重要。2) 整个训练循环中,图的训练方式不会改变(这意味着没有依赖于迭代的控制流)。当 static_graph 设置为 True 时,DDP 将支持过去无法支持的情况:1) 可重入反向传播。2) 激活检查点多次。3) 模型具有未使用参数时的激活检查点。4) 存在位于 forward 函数之外的模型参数。5) 当存在未使用参数时,可能会提高性能,因为当 static_graph 设置为 True 时,DDP 不会在每次迭代中搜索图来检测未使用参数。要检查是否可以将 static_graph 设置为 True,一种方法是在上一次模型训练结束时检查 ddp 日志数据,如果 ddp_logging_data.get("can_set_static_graph") == True,则很可能也可以设置 static_graph = True

    示例:
    >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
    >>> # Training loop
    >>> ...
    >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
    >>> static_graph = ddp_logging_data.get("can_set_static_graph")
    

  • delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter) – 当 param_to_hook_all_reduce 中指定的参数梯度就绪时,其 all reduce 将被延迟的命名参数列表。DDP 的其他参数不适用于此参数中指定的命名参数,因为这些命名参数将被 DDP reducer 忽略。

  • param_to_hook_all_reduce (torch.nn.Parameter) – 用于钩住 delay_all_reduce_named_params 中指定的参数延迟 all reduce 的参数。

变量

module (Module) – 要并行的模块。

示例

>>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
>>> net = torch.nn.parallel.DistributedDataParallel(model)
join(divide_by_initial_world_size=True, enable=True, throw_on_early_termination=False)[source][source]

用于在 DDP 中处理跨进程不均匀输入的训练的上下文管理器。

此上下文管理器将跟踪已加入的 DDP 进程,并通过插入集体通信操作来“影子”前向和反向传播,以匹配未加入的 DDP 进程创建的操作。这将确保每个集体调用都有已加入 DDP 进程的相应调用,从而防止在使用跨进程不均匀输入进行训练时可能发生的挂起或错误。此外,如果将标志 throw_on_early_termination 指定为 True,则一旦某个 rank 用尽输入,所有 trainer 都将抛出错误,从而允许根据应用逻辑捕获和处理这些错误。

一旦所有 DDP 进程都已加入,上下文管理器会将与最后一个加入进程对应的模型广播到所有进程,以确保模型在所有进程中相同(这是 DDP 保证的)。

要使用此功能启用跨进程不均匀输入的训练,只需将此上下文管理器包装在训练循环周围即可。无需对模型或数据加载进行进一步修改。

警告

如果此上下文管理器所包装的模型或训练循环具有额外的分布式集体操作,例如模型前向传播中的 SyncBatchNorm,则必须启用标志 throw_on_early_termination。这是因为此上下文管理器不知道非 DDP 的集体通信。当任何一个 rank 耗尽输入时,此标志将导致所有 rank 抛出错误,从而允许在所有 rank 中捕获并从中恢复这些错误。

参数
  • divide_by_initial_world_size (bool) – 如果为 True,将用启动 DDP 训练时的初始 world_size 除以梯度。如果为 False,将在 allreduce 期间计算有效的 world size(尚未耗尽输入的 rank 数量)并用此值除以梯度。将 divide_by_initial_world_size=True 以确保包括不均匀输入在内的每个输入样本在对全局梯度的贡献方面具有相同的权重。这是通过即使遇到不均匀输入也始终将梯度除以初始 world_size 来实现的。如果将此设置为 False,我们将梯度除以剩余的节点数。这确保了与在较小的 world_size 上训练的对等性,尽管这也意味着不均匀输入将对全局梯度贡献更多。通常,对于训练任务的最后几个输入不均匀的情况,您希望将其设置为 True。在极端情况下,当输入数量存在较大差异时,将其设置为 False 可能会提供更好的结果。

  • enable (bool) – 是否启用不均匀输入检测。如果您知道输入在参与进程中是均匀的,可以传入 enable=False 来禁用。默认值为 True

  • throw_on_early_termination (bool) – 当至少一个 rank 耗尽输入时,是抛出错误还是继续训练。如果为 True,则第一个 rank 到达数据末尾时将抛出错误。如果为 False,则将继续以较小的有效 world size 进行训练,直到所有 rank 都已加入。请注意,如果指定此标志,则将忽略标志 divide_by_initial_world_size。默认值为 False

示例

>>> import torch
>>> import torch.distributed as dist
>>> import os
>>> import torch.multiprocessing as mp
>>> import torch.nn as nn
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     torch.cuda.set_device(rank)
>>>     model = nn.Linear(1, 1, bias=False).to(rank)
>>>     model = torch.nn.parallel.DistributedDataParallel(
>>>         model, device_ids=[rank], output_device=rank
>>>     )
>>>     # Rank 1 gets one more input than rank 0.
>>>     inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
>>>     with model.join():
>>>         for _ in range(5):
>>>             for inp in inputs:
>>>                 loss = model(inp).sum()
>>>                 loss.backward()
>>>     # Without the join() API, the below synchronization will hang
>>>     # blocking for rank 1's allreduce to complete.
>>>     torch.cuda.synchronize(device=rank)
join_hook(**kwargs)[source][source]

DDP join hook 通过在前向和反向传播中镜像通信,支持对不均匀输入进行训练。

参数

kwargs (字典) – 一个dict,包含在运行时修改 join hook 行为的任何关键字参数;所有共享同一 join 上下文管理器的 Joinable 实例都将接收到相同的 kwargs 值。

该 hook 支持以下关键字参数
divide_by_initial_world_size (布尔值, 可选)

如果 True,则梯度除以 DDP 启动时的初始 world size。如果 False,则梯度除以有效的 world size(即未 join 的进程数),这意味着不均匀的输入对全局梯度贡献更大。通常,如果输入不均匀程度很小,应将其设置为 True,但在极端情况下,可以设置为 False 以可能获得更好的结果。默认值为 True

no_sync()[源][源]

用于禁用 DDP 进程之间梯度同步的上下文管理器。

在此上下文中,梯度将累积到模块变量上,这些变量稍后将在退出上下文的第一次前向-后向传播(forward-backward pass)中进行同步。

示例

>>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
>>> with ddp.no_sync():
>>>     for input in inputs:
>>>         ddp(input).backward()  # no synchronization, accumulate grads
>>> ddp(another_input).backward()  # synchronize grads

警告

前向传播(forward pass)应该包含在上下文管理器内部,否则梯度仍将被同步。

register_comm_hook(state, hook)[源][源]

注册通信 hook,用于跨多个 worker 对梯度进行用户自定义的 DDP 聚合。

这个 hook 对于研究人员尝试新想法非常有用。例如,可以使用此 hook 实现多种算法,如 GossipGrad 和梯度压缩,这些算法在运行分布式数据并行(Distributed DataParallel)训练时涉及不同的参数同步通信策略。

参数
  • state (对象) –

    传递给 hook,用于在训练过程中维护任何状态信息。示例包括梯度压缩中的错误反馈、GossipGrad 中接下来要通信的对等节点等。

    它由每个 worker 本地存储,并由该 worker 上的所有梯度张量(gradient tensor)共享。

  • hook (可调用对象) –

    可调用对象,具有以下签名:hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]

    一旦 bucket 准备就绪,就会调用此函数。hook 可以执行任何所需的处理,并返回一个 Future,表示任何异步工作(例如 allreduce)的完成。如果 hook 不执行任何通信,它仍然必须返回一个已完成的 Future。该 Future 应持有 grad bucket 中张量的新值。一旦 bucket 准备就绪,c10d reducer 将调用此 hook,并使用 Future 返回的张量,并将梯度复制到各个参数。请注意,Future 的返回类型必须是单个张量。

    我们还提供一个名为 get_future 的 API,用于检索与 c10d.ProcessGroup.Work 完成相关的 Future。 get_future 目前支持 NCCL,并且支持 GLOO 和 MPI 上的大多数操作,但点对点(peer to peer)操作(send/recv)除外。

警告

Grad bucket 中的张量不会预先除以 world_size。在执行 allreduce 等操作时,用户负责除以 world_size。

警告

DDP 通信 hook 只能注册一次,并且应在调用 backward 之前注册。

警告

hook 返回的 Future 对象应包含一个与 grad bucket 内张量具有相同形状的单个张量。

警告

get_future API 支持 NCCL,以及部分 GLOO 和 MPI 后端(不支持点对点操作,如 send/recv),并将返回一个 torch.futures.Future

示例:

下面是一个 noop hook 的示例,它返回相同的张量。

>>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>>     fut = torch.futures.Future()
>>>     fut.set_result(bucket.buffer())
>>>     return fut
>>> ddp.register_comm_hook(state=None, hook=noop)
示例:

下面是一个并行 SGD 算法的示例,其中梯度在 allreduce 之前被编码,然后在 allreduce 之后被解码。

>>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
>>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
>>>     fut = torch.distributed.all_reduce(encoded_tensor).get_future()
>>>     # Define the then callback to decode.
>>>     def decode(fut):
>>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
>>>         return decoded_tensor
>>>     return fut.then(decode)
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源