• 文档 >
  • 通用 Join 上下文管理器
快捷方式

通用 Join 上下文管理器

通用 join 上下文管理器有助于在不均匀输入上进行分布式训练。本页概述了相关类的 API:JoinJoinableJoinHook。有关教程,请参阅 使用 Join 上下文管理器进行不均匀输入的分布式训练

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[源代码][源代码]

此类定义了通用 join 上下文管理器,允许在进程加入后调用自定义钩子。

这些钩子应该模拟未加入进程的集合通信,以防止挂起和错误,并确保算法正确性。有关钩子定义的详细信息,请参阅 JoinHook

警告

上下文管理器要求每个参与的 Joinable 在其自身的每次迭代集合通信之前调用方法 notify_join_context(),以确保正确性。

警告

上下文管理器要求所有 JoinHook 对象中的 process_group 属性都相同。如果存在多个 JoinHook 对象,则使用第一个的 device。进程组和设备信息用于检查未加入的进程,并在启用 throw_on_early_termination 时通知进程抛出异常,这两者都使用 all-reduce。

参数
  • joinables (List[Joinable]) – 参与的 Joinable 列表;其钩子按给定顺序进行迭代。

  • enable (bool) – 一个标志,用于启用不均匀输入检测;将其设置为 False 会禁用上下文管理器的功能,仅当用户知道输入不会不均匀时才应设置(默认值:True)。

  • throw_on_early_termination (bool) – 一个标志,控制在检测到不均匀输入时是否抛出异常(默认值:False)。

示例

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
static notify_join_context(joinable)[源代码][源代码]

通知 join 上下文管理器调用进程尚未加入。

然后,如果 throw_on_early_termination=True,则检查是否已检测到不均匀输入(即是否有进程已加入),如果检测到则抛出异常。

此方法应在 Joinable 对象的每次迭代集合通信之前调用。例如,应在 DistributedDataParallel 的前向传播开始时调用此方法。

只有传递给上下文管理器的第一个 Joinable 对象在此方法中执行集合通信,对于其他对象,此方法是空操作。

参数

joinable (Joinable) – 调用此方法的 Joinable 对象。

返回值

如果 joinable 是传递给上下文管理器的第一个对象,则返回用于通知上下文管理器进程尚未加入的 all-reduce 的异步工作句柄;否则返回 None

class torch.distributed.algorithms.Joinable[源代码][源代码]

这定义了一个可加入类的抽象基类。

可加入类(继承自 Joinable)除了应实现返回设备和进程组信息的 join_device()join_process_group() 方法外,还应实现返回 JoinHook 实例的 join_hook() 方法。

abstract property join_device: device

返回执行 join 上下文管理器所需的集合通信的设备。

abstract join_hook(**kwargs)[源代码][源代码]

为给定的 Joinable 返回一个 JoinHook 实例。

参数

kwargs (dict) – 一个 dict,包含用于在运行时修改 join 钩子行为的任意关键字参数;所有共享同一 join 上下文管理器的 Joinable 实例都将收到相同的 kwargs 值。

返回类型

JoinHook

abstract property join_process_group: Any

返回 join 上下文管理器自身所需的集合通信的进程组。

class torch.distributed.algorithms.JoinHook[源代码][源代码]

这定义了一个 join 钩子,它在 join 上下文管理器中提供了两个入口点。

入口点:一个主钩子 (main hook),当存在未加入的进程时会重复调用;以及一个后钩子 (post-hook),在所有进程都加入后调用一次。

要为通用 join 上下文管理器实现 join 钩子,请定义一个继承自 JoinHook 的类,并根据需要重写 main_hook()post_hook() 方法。

main_hook()[源代码][源代码]

当存在未加入的进程时调用此钩子,以模拟训练迭代中的集合通信。

训练迭代,即一次前向传播、一次后向传播和一次优化器步。

post_hook(is_last_joiner)[源代码][源代码]

在所有进程加入后调用钩子。

它会接收一个额外的 bool 参数 is_last_joiner,指示该 rank 是否是最后加入的之一。

参数

is_last_joiner (bool) – 如果该 rank 是最后加入的之一,则为 True;否则为 False

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源