• 教程 >
  • 使用 Join 上下文管理器进行不均匀输入的分布式训练
快捷方式

使用 Join 上下文管理器进行不均匀输入的分布式训练

作者: Andrew Gu

注意

editgithub 中查看和编辑本教程。

注意

Join 是在 PyTorch 1.10 中引入的原型功能。此 API 可能会发生更改。

在本教程中,您将了解

  • Join 上下文管理器的概述。

  • 使用 DistributedDataParallel 使用上下文管理器的示例。

  • 使用 DistributedDataParallelZeroRedundancyOptimizer 使用上下文管理器的示例。

  • 将关键字参数传递给上下文管理器的示例。

  • 深入了解 Join 上下文管理器的工作原理。

  • 一个示例,展示如何使一个玩具类与上下文管理器兼容。

什么是 Join?

分布式数据并行入门 - 基本用例 中,您看到了使用 DistributedDataParallel 执行数据并行训练的一般框架。这在每次反向传播中隐式地安排所有约简,以跨排名同步梯度。这种 集体通信 需要进程组中所有排名的参与,因此如果一个排名有较少的输入,那么其他排名将挂起或出错(取决于后端)。更一般地,这个问题会持续存在于任何在每次迭代中执行同步集体通信的类中。

Join 是一个上下文管理器,用于围绕您的每个排名训练循环,以促进对不均匀输入的训练。上下文管理器允许过早耗尽输入的排名(即加入过早)来跟踪由尚未加入的排名执行的集体通信。通信被跟踪的方式由钩子指定。

JoinDistributedDataParallel 结合使用

PyTorch 的 DistributedDataParallelJoin 上下文管理器开箱即用。这是一个示例用法

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将产生以下输出(其中来自排名 0 和排名 1 的 print() 可能被任意排序)

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

注意

DistributedDataParallel 在引入这个通用的 Join 上下文管理器之前,提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]): 等效于使用 with model.join():。现有的 DistributedDataParallel.join() 的一个限制是它不允许使用多个参与类,例如 DistributedDataParallelZeroRedundancyOptimizer 同时使用。

JoinDistributedDataParallelZeroRedundancyOptimizer 结合使用

Join 上下文管理器不仅适用于单个类,还适用于多个类同时使用。PyTorch 的 ZeroRedundancyOptimizer 也与上下文管理器兼容,因此在这里,我们检查如何修改前面的示例以同时使用 DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将产生与之前相同的输出。值得注意的变化是将 ZeroRedundancyOptimizer 实例额外传递到 Join() 中。

传递关键字参数

类可以提供关键字参数,在运行时修改它们在上下文管理器中的行为。例如,DistributedDataParallel 提供一个参数 divide_by_initial_world_size,它决定梯度是否按初始世界大小或有效世界大小(即非加入排名的数量)划分。此类关键字参数可以直接传递到上下文管理器中。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递到上下文管理器的关键字参数在所有参与类之间共享。这不应该是一个限制,因为我们不期望有多个 Joinable 需要同一参数的不同设置。尽管如此,这还是需要注意的事情。

Join 是如何工作的?

既然我们已经了解了一些有关如何使用 Join 上下文管理器的初步示例,让我们更深入地了解它是如何工作的。这将为您提供对其提供的全部功能的更深入了解,并为您准备使您自己的自定义类兼容。在这里,我们将介绍 Join 类以及支持类 JoinableJoinHook

Joinable

首先,与 Join 上下文管理器兼容的类必须继承自抽象基类 Joinable。特别是,一个 Joinable 必须实现

  • join_hook(self, **kwargs) -> JoinHook

这将返回 JoinableJoinHook 实例,确定已加入的进程应该如何跟踪 Joinable 在每次训练迭代中执行的每次迭代集体通信(例如,在一次前向传递、反向传递和优化器步骤中)。

  • join_device(self) -> torch.device

这将返回一个设备,供 Join 上下文管理器用于执行集体通信,例如 torch.device("cuda:0")torch.device("cpu")

  • join_process_group(self) -> ProcessGroup

这将返回进程组,供 Join 上下文管理器用于执行集体通信。

特别是,join_devicejoin_process_group 是必需的属性,以确保上下文管理器可以在已加入和未加入的进程之间安排集体通信。一个用法是使用所有约简来计算每次迭代中未加入进程的数量。另一个用法是实现 throw_on_early_termination=True 所需的机制,我们将在下面的内容中解释。

DistributedDataParallelZeroRedundancyOptimizer 已经继承自 Joinable 并实现了上述方法,这就是为什么我们可以在前面的示例中直接使用它们的原因。

Joinable 类应该确保调用 Joinable 构造函数,因为它初始化一个 JoinConfig 实例,该实例由上下文管理器在内部使用以确保正确性。这将被保存到每个 Joinable 中,作为一个字段 _join_config

JoinHook

接下来,让我们分解 JoinHook 类。一个 JoinHook 为上下文管理器提供了两个入口点

  • main_hook(self) -> None

当存在尚未加入的排名时,每个已加入的排名会重复调用此钩子。它的目的是跟踪 Joinable 在每次训练迭代中执行的集体通信(例如,在一个前向传递、反向传递和优化器步骤中)。

  • post_hook(self, is_last_joiner: bool) -> None

当所有排名都已加入时,此钩子将被调用一次。它将传递一个额外的 bool 参数 is_last_joiner,它指示排名是否为最后加入的排名之一。该参数可能对同步有用。

为了给出这些钩子可能是什么样子的具体示例,提供的 ZeroRedundancyOptimizer 主钩子会按正常方式执行优化器步骤,因为已加入的排名仍然负责更新和同步其参数的碎片,并且提供的 DistributedDataParallel 后处理钩子将最终更新的模型从最后一个加入的排名之一广播出去,以确保它在所有排名中都相同。

Join

最后,让我们检查一下它们如何融入 Join 类本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

正如我们在前面的示例中看到的,构造函数接收参与训练循环的 Joinable 的列表。这些应该是每次迭代中执行集体通信的类。

enable 是一个 bool,如果知道没有不均匀的输入,则可以将其设置为 False,在这种情况下,上下文管理器会变得空洞,类似于 contextlib.nullcontext()。这也会禁用参与的 Joinable 中与加入相关的计算。

throw_on_early_termination 是一个 bool,可以将其设置为 True,以使每个排名在检测到不均匀的输入时立即引发异常。这对于不符合上下文管理器要求的情况非常有用,这种情况最常见于使用 DistributedDataParallel 时,模型具有 SyncBatchNorm 层。在这种情况下,此参数应该设置为 True,以便应用程序逻辑可以捕获异常并确定如何继续。

  • 核心逻辑发生在 __exit__() 方法中,该方法循环遍历所有未加入的秩,调用每个 Joinable 的主钩子,然后在所有秩都加入后,调用它们的后期钩子。主钩子和后期钩子都按照 Joinable 传入的顺序进行迭代。

  • 上下文管理器需要来自未加入进程的心跳。因此,每个 Joinable 类应该在每次迭代的集体通信之前调用 Join.notify_join_context()。上下文管理器将确保只传递的第一个 Joinable 实际上发送心跳。

警告

如上所述关于 throw_on_early_terminationJoin 上下文管理器与某些类的组合不兼容。 JoinableJoinHook 必须是可序列化的,因为每个钩子在继续执行下一个钩子之前都会完全执行。换句话说,两个钩子不能重叠。此外,目前,主钩子和后期钩子都以相同的确定性顺序进行迭代。如果这看起来是一个主要的限制,我们可能会修改 API 以允许自定义排序。

使玩具类与 Join 协同工作

由于上一节介绍了几个概念,让我们通过一个玩具示例在实践中看到它们。在这里,我们将实现一个类,该类在它的秩加入之前统计所有秩上看到的所有输入的数量。这应该提供一个基本的思路,说明如何使你自己的类与 Join 上下文管理器兼容。

具体来说,以下代码让每个秩打印出 (1) 在它加入之前所有秩看到的输入数量,以及 (2) 所有秩的总输入数量。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于秩 0 看到 5 个输入,秩 1 看到 6 个输入,这将产生以下输出

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

一些需要强调的关键点

  • 一个 Counter 实例每次迭代执行一次全约简,因此主钩子也执行一次全约简来对其进行遮蔽。

  • Counter 类在 __call__() 方法的开头调用 Join.notify_join_context(),因为这是在每次迭代的集体通信(即它的全约简)之前的一个地方。

  • is_last_joiner 参数用于确定后期钩子中的广播源。

  • 我们将 sync_max_count 关键字参数传递给上下文管理器,然后将其转发到 Counter 的连接钩子。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得答案

查看资源