• 教程 >
  • 使用 Join 上下文管理器进行不等量输入的分布式训练
快捷方式

使用 Join 上下文管理器进行不等量输入的分布式训练

创建日期:2021 年 8 月 4 日 | 最后更新:2023 年 1 月 9 日 | 最后验证:2024 年 11 月 5 日

作者: Andrew Gu

注意

editgithub 中查看和编辑本教程。

注意

Join 在 PyTorch 1.10 中作为原型功能引入。此 API 可能随时更改。

在本教程中,您将了解

  • Join 上下文管理器的概览。

  • 如何将上下文管理器与 DistributedDataParallel 一起使用的示例。

  • 如何将上下文管理器与 DistributedDataParallelZeroRedundancyOptimizer 一起使用的示例。

  • 将关键字参数传递给上下文管理器的示例。

  • 深入了解 Join 上下管理器的工作原理。

  • 演示如何使一个玩具类与上下文管理器兼容的示例。

什么是 Join

分布式数据并行入门 - 基本用例 中,您了解了使用 DistributedDataParallel 执行数据并行训练的基本框架。这会在每个反向传播过程中隐式安排 all-reduce 操作,以同步所有 rank 上的梯度。此类 集体通信 需要进程组中所有 rank 的参与,因此如果某个 rank 输入较少,其他 rank 将会挂起或出错(取决于后端)。更普遍地,对于任何执行逐迭代同步集体通信的类,都会存在此问题。

Join 是一个上下文管理器,用于封装每个 rank 的训练循环,以促进处理不等量输入的训练。该上下文管理器允许提前耗尽输入的 rank(即提前 join 的 rank)模拟尚未 join 的 rank 所执行的集体通信。模拟通信的方式由 hook 指定。

JoinDistributedDataParallel 一起使用

PyTorch 的 DistributedDataParallel 可以直接与 Join 上下文管理器一起使用。以下是一个示例用法

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将产生以下输出(其中 rank 0 和 rank 1 的 print() 输出顺序可能不固定)

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

在引入这个通用 Join 上下文管理器之前,DistributedDataParallel 提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]): 等同于使用 with model.join():。现有的 DistributedDataParallel.join() 的一个限制是它不允许有多个参与类,例如 DistributedDataParallelZeroRedundancyOptimizer 同时使用。

JoinDistributedDataParallelZeroRedundancyOptimizer 一起使用

Join 上下文管理器不仅适用于单个类,也适用于多个类。PyTorch 的 ZeroRedundancyOptimizer 也与此上下文管理器兼容,因此在此,我们探讨如何修改前面的示例,使其同时使用 DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与之前相同的输出。值得注意的改变是额外将 ZeroRedundancyOptimizer 实例传递给了 Join()

传递关键字参数

类可以提供关键字参数,以在运行时修改其在上下文管理器中的行为。例如,DistributedDataParallel 提供了一个参数 divide_by_initial_world_size,它决定了梯度是除以初始 world size 还是有效 world size(即未 join 的 rank 数量)。此类关键字参数可以直接传递给上下文管理器。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递给上下文管理器的关键字参数在所有参与类之间共享。这应该不是一个限制,因为我们不期望出现多个 Joinable 需要对同一参数进行不同设置的情况。尽管如此,这仍是一个需要记住的点。

Join 如何工作?

既然我们已经看到了一些使用 Join 上下文管理器的初步示例,现在让我们深入了解它的工作原理。这将为您提供对其全部功能的更深入的理解,并帮助您使自己的自定义类兼容。在这里,我们将介绍 Join 类以及支持类 JoinableJoinHook

Joinable

首先,与 Join 上下文管理器兼容的类必须继承自抽象基类 Joinable。具体来说,一个 Joinable 必须实现

  • join_hook(self, **kwargs) -> JoinHook

此方法返回该 JoinableJoinHook 实例,它决定了已 join 的进程如何模拟该 Joinable 执行的逐迭代集体通信。

  • join_device(self) -> torch.device

此方法返回 Join 上下文管理器用于执行集体通信的设备,例如 torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

此方法返回 Join 上下文管理器用于执行集体通信的进程组。

具体来说,join_devicejoin_process_group 是必需的属性,以确保上下文管理器可以安排已 join 和未 join 进程之间的集体通信。其中一个用途是使用 all-reduce 在每次迭代中计算未 join 的进程数量。另一个用途是实现 throw_on_early_termination=True 所需的机制,我们将在下文解释这一点。

DistributedDataParallelZeroRedundancyOptimizer 已经继承自 Joinable 并实现了上述方法,这就是为什么我们可以在之前的示例中直接使用它们。

Joinable 类应确保调用 Joinable 构造函数,因为它会初始化一个 JoinConfig 实例,该实例被上下文管理器内部使用以确保正确性。这将被保存在每个 Joinable 中作为一个字段 _join_config

JoinHook

接下来,让我们分解 JoinHook 类。一个 JoinHook 为上下文管理器提供了两个入口点

  • main_hook(self) -> None

当存在尚未 join 的 rank 时,每个已 join 的 rank 会重复调用此 hook。它旨在模拟 Joinable 在每次训练迭代中(例如在一个前向传播、反向传播和优化器步骤中)执行的集体通信。

  • post_hook(self, is_last_joiner: bool) -> None

当所有 rank 都已 join 后,此 hook 会被调用一次。它会传递一个额外的布尔参数 is_last_joiner,表示该 rank 是否是最后 join 的 rank 之一。此参数可能对同步有用。

为了提供这些 hook 具体示例,提供的 ZeroRedundancyOptimizer 的 main hook 会正常执行优化器步骤,因为已 join 的 rank 仍负责更新和同步其参数分片;而提供的 DistributedDataParallel 的 post-hook 会从最后一个 join 的 rank 之一广播最终更新的模型,以确保所有 rank 上的模型一致。

Join

最后,让我们看看这些是如何融入 Join 类本身的。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的示例中看到的,构造函数接受一个参与训练循环的 Joinable 列表。这些应该是每个迭代中执行集体通信的类。

enable 是一个布尔值,如果你知道不会有不等量输入,可以将其设置为 False,在这种情况下,上下文管理器将变得空洞,类似于 contextlib.nullcontext()。这还可能禁用参与的 Joinable 中的 join 相关计算。

throw_on_early_termination 是一个布尔值,可以设置为 True,以便在检测到不等量输入的那一刻让每个 rank 抛出异常。这对于不符合上下文管理器要求的场景非常有用,最典型的是当不同类之间的集体通信可能任意交错时,例如将 DistributedDataParallel 与包含 SyncBatchNorm 层的模型一起使用。在这种情况下,应将此参数设置为 True,以便应用逻辑可以捕获异常并确定如何处理。

  • 核心逻辑发生在 __exit__() 方法中,该方法在存在未 join 的 rank 时循环,调用每个 Joinable 的 main hook,然后一旦所有 rank 都已 join,调用它们的 post hook。main hook 和 post hook 都按照 Joinable 传入的顺序进行迭代。

  • 上下文管理器需要未 join 进程的心跳。因此,每个 Joinable 类应在其逐迭代集体通信之前调用 Join.notify_join_context()。上下文管理器将确保只有传入的第一个 Joinable 实际发送心跳。

警告

如上文关于 throw_on_early_termination 所述,Join 上下文管理器与某些类的组合不兼容。JoinableJoinHook 必须是可序列化的,因为每个 hook 都必须完全执行后才能进行下一个。换句话说,两个 hook 不能重叠。此外,目前 main hook 和 post hook 都按照相同的确定性顺序进行迭代。如果这看起来是一个主要限制,我们可能会修改 API 以允许自定义排序。

使一个玩具类与 Join 一起工作

由于上一节介绍了一些概念,现在让我们通过一个玩具示例来实际应用它们。在这里,我们将实现一个类,该类在某个 rank join 之前计算所有 rank 上看到的输入总数。这应该能让您初步了解如何使自己的类与 Join 上下文管理器兼容。

具体来说,以下代码会让每个 rank 打印出 (1) 在该 rank join 之前所有 rank 上看到的输入数量总计,以及 (2) 所有 rank 上的总输入数量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于 rank 0 看到 5 个输入,rank 1 看到 6 个,这将产生以下输出

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

一些需要强调的关键点

  • 一个 Counter 实例每迭代执行一次 all-reduce,因此 main hook 也执行一次 all-reduce 来模拟它。

  • Counter 类在其 __call__() 方法的开始处调用 Join.notify_join_context(),因为这是在其逐迭代集体通信(即其 all-reduce)之前的位置。

  • is_last_joiner 参数用于确定 post-hook 中的广播源。

  • 我们将 sync_max_count 关键字参数传递给上下文管理器,该参数随后被转发给 Counter 的 join hook。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源