使用 Join Context Manager 进行输入不均的分布式训练¶
创建于:2021 年 8 月 4 日 | 最后更新:2023 年 1 月 9 日 | 最后验证:2024 年 11 月 5 日
作者: Andrew Gu
注意
在 github 中查看和编辑本教程。
注意
Join
在 PyTorch 1.10 中作为原型功能引入。此 API 可能会发生更改。
在本教程中,您将看到
Join 上下文管理器的概述。
如何将上下文管理器与
DistributedDataParallel
一起使用的示例。如何将上下文管理器与
DistributedDataParallel
和ZeroRedundancyOptimizer
一起使用的示例。将关键字参数传递给上下文管理器的示例。
深入了解 Join 上下文管理器的工作原理。
一个示例,展示如何使玩具类与上下文管理器兼容。
要求¶
PyTorch 1.10+
Join
是什么?¶
在 分布式数据并行入门 - 基本用例 中,您看到了使用 DistributedDataParallel 执行数据并行训练的通用框架。这隐式地在每个反向传播中调度 all-reduce 以同步跨 rank 的梯度。此类 集体通信 需要进程组中所有 rank 的参与,因此如果一个 rank 的输入较少,则其他 rank 将挂起或出错(取决于后端)。更一般而言,对于任何执行每次迭代同步集体通信的类,此问题仍然存在。
Join
是一个上下文管理器,用于围绕您的每个 rank 训练循环,以促进输入不均的训练。上下文管理器允许提前耗尽输入的 rank(即提前加入)来遮蔽尚未加入的 rank 执行的集体通信。通信被遮蔽的方式由钩子指定。
将 Join
与 DistributedDataParallel
一起使用¶
PyTorch 的 DistributedDataParallel 可以与 Join
上下文管理器开箱即用。以下是一个示例用法
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
with Join([model]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
这将产生以下输出(其中来自 rank 0 和 rank 1 的 print()
s 可能以任意顺序排列)
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!
注意
DistributedDataParallel 在引入此通用 Join
上下文管理器之前,提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]):
等效于使用 with model.join():
。现有 DistributedDataParallel.join()
的一个限制是它不允许多个参与类,例如 DistributedDataParallel
和 ZeroRedundancyOptimizer 一起。
将 Join
与 DistributedDataParallel
和 ZeroRedundancyOptimizer
一起使用¶
Join
上下文管理器不仅可以与单个类一起使用,还可以与多个类一起使用。PyTorch 的 ZeroRedundancyOptimizer
也与上下文管理器兼容,因此,在这里,我们研究如何修改之前的示例以同时使用 DistributedDataParallel
和 ZeroRedundancyOptimizer
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
optim = ZeRO(model.parameters(), Adam, lr=0.01)
# Rank 1 gets one more input than rank 0
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
num_inputs = 0
# Pass both `model` and `optim` into `Join()`
with Join([model, optim]):
for input in inputs:
num_inputs += 1
loss = model(input).sum()
loss.backward()
optim.step()
print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
这将产生与之前相同的输出。值得注意的变化是额外将 ZeroRedundancyOptimizer
实例传递到 Join()
中。
传递关键字参数¶
类可以提供关键字参数,这些参数在运行时修改它们在上下文管理器中的行为。例如,DistributedDataParallel
提供了一个参数 divide_by_initial_world_size
,它确定梯度是否除以初始世界大小或有效世界大小(即,未加入 rank 的数量)。此类关键字参数可以直接传递到上下文管理器中。
with Join([model, optim], divide_by_initial_world_size=False):
for input in inputs:
...
警告
传递到上下文管理器的关键字参数在所有参与类之间共享。这不应成为限制,因为我们不希望出现多个 Joinable
需要相同参数的不同设置的情况。尽管如此,这是需要牢记的事情。
Join
如何工作?¶
现在我们已经看到了一些关于如何使用 Join
上下文管理器的初步示例,让我们深入研究它的工作原理。这将更深入地了解它提供的全部功能,并让您准备好使您自己的自定义类兼容。在这里,我们将介绍 Join
类以及支持类 Joinable
和 JoinHook
。
Joinable
¶
首先,与 Join
上下文管理器兼容的类必须从抽象基类 Joinable
继承。特别是,Joinable
必须实现
join_hook(self, **kwargs) -> JoinHook
这会返回 Joinable
的 JoinHook
实例,确定已加入的进程应如何遮蔽 Joinable
执行的每次迭代集体通信。
join_device(self) -> torch.device
这会返回 Join
上下文管理器使用的设备,以执行集体通信,例如 torch.device("cuda:0")
或 torch.device("cpu")
。
join_process_group(self) -> ProcessGroup
这会返回 Join
上下文管理器使用的进程组,以执行集体通信。
特别是,join_device
和 join_process_group
是必需属性,以确保上下文管理器可以调度已加入和未加入进程之间的集体通信。一种用法是使用 all-reduce 来计算每次迭代中未加入进程的数量。另一种用法是实现 throw_on_early_termination=True
所需的机制,我们将在下面稍后解释。
DistributedDataParallel
和 ZeroRedundancyOptimizer
已经从 Joinable
继承并实现了上述方法,这就是为什么我们可以在之前的示例中直接使用它们的原因。
Joinable
类应确保调用 Joinable
构造函数,因为它初始化了一个 JoinConfig
实例,上下文管理器在内部使用该实例来确保正确性。这将保存在每个 Joinable
中,作为字段 _join_config
。
JoinHook
¶
接下来,让我们分解 JoinHook
类。JoinHook
为上下文管理器提供两个入口点
main_hook(self) -> None
当存在尚未加入的 rank 时,已加入的每个 rank 会重复调用此钩子。它旨在遮蔽 Joinable
在每次训练迭代(例如,在一个前向传播、反向传播和优化器步骤中)中执行的集体通信。
post_hook(self, is_last_joiner: bool) -> None
当所有 rank 都已加入后,将调用此钩子。它传递一个额外的 bool
参数 is_last_joiner
,指示该 rank 是否是最后加入的 rank 之一。该参数可能对同步有用。
为了给出这些钩子可能是什么样子的具体示例,提供的 ZeroRedundancyOptimizer
主钩子像往常一样执行优化器步骤,因为已加入的 rank 仍然负责更新和同步其参数分片,而提供的 DistributedDataParallel
后钩子从最后加入的 rank 之一广播最终更新的模型,以确保所有 rank 上的模型都相同。
Join
¶
最后,让我们检查一下这些如何融入 Join
类本身。
__init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)
正如我们在之前的示例中看到的那样,构造函数接受参与训练循环的 Joinable
s 列表。这些应该是每次迭代中执行集体通信的类。
enable
是一个 bool
,如果您知道不会有输入不均的情况,则可以将其设置为 False
,在这种情况下,上下文管理器将变得空洞,类似于 contextlib.nullcontext()
。这也可能会禁用参与 Joinable
s 中的与 join 相关的计算。
throw_on_early_termination
是一个 bool
,可以将其设置为 True
,以便在检测到输入不均的那一刻,每个 rank 都会引发异常。这对于不符合上下文管理器要求的案例很有用,最常见的情况是当来自不同类的集体通信可能任意交错时,例如当将 DistributedDataParallel
与具有 SyncBatchNorm
层的模型一起使用时。在这种情况下,应将此参数设置为 True
,以便应用程序逻辑可以捕获异常并确定如何继续。
核心逻辑发生在
__exit__()
方法中,该方法在存在未加入的 rank 时循环,调用每个Joinable
的主钩子,然后在所有 rank 都已加入后,调用它们的后钩子。主钩子和后钩子都按照Joinable
s 传入的顺序迭代。上下文管理器需要来自未加入进程的心跳。因此,每个
Joinable
类都应在其每次迭代集体通信之前调用Join.notify_join_context()
。上下文管理器将确保只有第一个传入的Joinable
实际发送心跳。
警告
如上所述关于 throw_on_early_termination
,Join
上下文管理器与某些类的组合不兼容。Joinable
的 JoinHook
s 必须是可序列化的,因为每个钩子在继续下一个钩子之前都会完全执行。换句话说,两个钩子不能重叠。此外,目前,主钩子和后钩子都以相同的确定性顺序迭代。如果这似乎是一个主要限制,我们可能会修改 API 以允许自定义排序。
使玩具类与 Join
一起工作¶
由于上一节介绍了几个概念,让我们通过一个玩具示例在实践中看看它们。在这里,我们将实现一个类,该类计算在其 rank 加入之前在所有 rank 中看到的输入数量。这应该提供一个基本概念,说明如何使您自己的类与 Join
上下文管理器兼容。
具体来说,以下代码使每个 rank 打印出 (1) 在其加入之前在所有 rank 中看到的输入数量,以及 (2) 所有 rank 中输入的总数。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
class CounterJoinHook(JoinHook):
r"""
Join hook for :class:`Counter`.
Arguments:
counter (Counter): the :class:`Counter` object using this hook.
sync_max_count (bool): whether to sync the max count once all ranks
join.
"""
def __init__(
self,
counter,
sync_max_count
):
self.counter = counter
self.sync_max_count = sync_max_count
def main_hook(self):
r"""
Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
"""
t = torch.zeros(1, device=self.counter.device)
dist.all_reduce(t)
def post_hook(self, is_last_joiner: bool):
r"""
Synchronizes the max count across all :class:`Counter` s if
``sync_max_count=True``.
"""
if not self.sync_max_count:
return
rank = dist.get_rank(self.counter.process_group)
common_rank = self.counter.find_common_rank(rank, is_last_joiner)
if rank == common_rank:
self.counter.max_count = self.counter.count.detach().clone()
dist.broadcast(self.counter.max_count, src=common_rank)
class Counter(Joinable):
r"""
Example :class:`Joinable` that counts the number of training iterations
that it participates in.
"""
def __init__(self, device, process_group):
super(Counter, self).__init__()
self.device = device
self.process_group = process_group
self.count = torch.tensor([0], device=device).float()
self.max_count = torch.tensor([0], device=device).float()
def __call__(self):
r"""
Counts the number of inputs processed on this iteration by all ranks
by all-reducing a dim-1 one tensor; increments its own internal count.
"""
Join.notify_join_context(self)
t = torch.ones(1, device=self.device).float()
dist.all_reduce(t)
self.count += t
def join_hook(self, **kwargs) -> JoinHook:
r"""
Return a join hook that shadows the all-reduce in :meth:`__call__`.
This join hook supports the following keyword arguments:
sync_max_count (bool, optional): whether to synchronize the maximum
count across all ranks once all ranks join; default is ``False``.
"""
sync_max_count = kwargs.get("sync_max_count", False)
return CounterJoinHook(self, sync_max_count)
@property
def join_device(self) -> torch.device:
return self.device
@property
def join_process_group(self):
return self.process_group
def find_common_rank(self, rank, to_consider):
r"""
Returns the max rank of the ones to consider over the process group.
"""
common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
common_rank = common_rank.item()
return common_rank
def worker(rank):
assert torch.cuda.device_count() >= WORLD_SIZE
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
with Join([counter], sync_max_count=True):
for _ in inputs:
counter()
print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
print(f"{int(counter.max_count.item())} inputs processed across all ranks!")
def main():
mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
if __name__ == "__main__":
main()
由于 rank 0 看到 5 个输入,rank 1 看到 6 个输入,因此产生以下输出
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
要突出显示的一些关键点
Counter
实例每次迭代执行一次 all-reduce,因此主钩子也执行一次 all-reduce 以遮蔽它。Counter
类在其__call__()
方法的开头调用Join.notify_join_context()
,因为那是其每次迭代集体通信(即其 all-reduce)之前的位置。is_last_joiner
参数用于确定后钩子中的广播源。我们将
sync_max_count
关键字参数传递给上下文管理器,然后将其转发到Counter
的 join 钩子。