管道并行¶
管道并行最初在 Gpipe 论文中提出,是一种在多个 GPU 上训练大型模型的高效技术。
警告
torch.distributed.pipeline 已弃用,本文档也已弃用。有关最新的管道并行实现,请参阅 PyTorch 组织下的 PiPPy 库(PyTorch 的管道并行)。
使用多个 GPU 的模型并行¶
通常对于不适合单个 GPU 的大型模型,会采用模型并行,其中模型的某些部分放置在不同的 GPU 上。但是,如果对顺序模型天真地执行此操作,则训练过程会因 GPU 利用率不足而受到影响,因为一次只激活一个 GPU,如下面的图片所示
流水线执行¶
为了缓解此问题,管道并行将输入小批量拆分为多个微小批量,并将这些微小批量的执行流水线化到多个 GPU 上。这在下面的图片中概述
PyTorch 中的管道 API¶
- 类 torch.distributed.pipeline.sync.Pipe(module, chunks=1, checkpoint='except_last', deferred_batch_norm=False)[源代码]¶
包装任意
nn.Sequential
模块以使用同步管道并行进行训练。如果模块需要大量内存且无法容纳在单个 GPU 上,则管道并行是一种可用于训练的有用技术。该实现基于 torchgpipe 论文。
Pipe 将管道并行与检查点结合起来,以减少训练所需的峰值内存,同时最大程度地减少设备未充分利用的情况。
您应该将所有模块放在适当的设备上,并将它们包装到
nn.Sequential
模块中,定义所需的执行顺序。如果模块不包含任何参数/缓冲区,则假定此模块应在 CPU 上执行,并且在执行之前将模块的适当输入张量移动到 CPU。此行为可以通过WithDevice
包装器覆盖,该包装器可用于明确指定模块应在哪个设备上运行。- 参数
module (
nn.Sequential
) – 使用管道并行化的顺序模块。序列中的每个模块都必须将其所有参数放在单个设备上。序列中的每个模块必须是 nn.Module 或nn.Sequential
(以将多个顺序模块组合在单个设备上)chunks (int) – 微批次数量(默认值:
1
)checkpoint (str) – 何时启用检查点,
'always'
、'except_last'
或'never'
之一(默认值:'except_last'
)。'never'
完全禁用检查点,'except_last'
为除最后一个之外的所有微批次启用检查点,'always'
为所有微批次启用检查点。deferred_batch_norm (bool) – 是否使用延迟的
BatchNorm
移动统计信息(默认值:False
)。如果设置为True
,我们将在多个微批次中跟踪统计信息,以按每个小批次更新运行统计信息。
- 引发
TypeError – 模块不是
nn.Sequential
。ValueError – 无效参数
- 示例:
跨 GPU 0 和 1 的两个 FC 层管道。
>>> # Need to initialize RPC framework first. >>> os.environ['MASTER_ADDR'] = 'localhost' >>> os.environ['MASTER_PORT'] = '29500' >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1) >>> >>> # Build pipe. >>> fc1 = nn.Linear(16, 8).cuda(0) >>> fc2 = nn.Linear(8, 4).cuda(1) >>> model = nn.Sequential(fc1, fc2) >>> model = Pipe(model, chunks=8) >>> input = torch.rand(16, 16).cuda(0) >>> output_rref = model(input)
注意
仅当
Pipe
模型的 checkpoint 参数为'never'
时,才能使用torch.nn.parallel.DistributedDataParallel
封装Pipe
。注意
Pipe
目前仅支持节点内流水线,但未来将扩展为支持节点间流水线。forward 函数返回一个RRef
,以便将来允许节点间流水线,其中输出可能位于远程主机上。对于节点内流水线,可以使用local_value()
在本地检索输出。警告
Pipe
是实验性的,可能会发生变化。- forward(*inputs)[源代码]¶
通过管道处理单个输入小批量,并返回指向输出的
RRef
。Pipe
是一个相当透明的模块包装器。它不会修改底层模块的输入和输出签名。但存在类型限制。输入和输出必须至少包含一个张量。此限制也应用于分区边界。输入序列作为
*inputs
馈送到管道的第一个阶段。因此,此函数的位置参数应与管道第一个阶段的位置参数匹配。对于管道的某个阶段的输出(即下一个阶段的输入),也适用相同条件。输入张量根据用于初始化
Pipe
的chunks
参数拆分为多个微小批量。假设批大小是张量的第一个维度,如果批大小小于chunks
,则微小批量的数量等于批大小。只有张量被拆分为多个微批次,非张量输入在每个微批次中按原样复制。对于管道最后阶段的非张量输出,它们被聚合为
List
并返回给用户。例如,如果您有 2 个微批次返回整数 5,则用户将收到合并后的输出 [5, 5]所有输入张量需要与管道的第一个分区位于同一设备上。
如果张量用
NoChunk
包装器包装,则该张量不会在微批次之间拆分,并且会按原样复制,类似于非张量。- 参数
inputs – 输入小批量
- 返回
RRef
到小批量的输出- 引发
TypeError – 输入不包含至少一个张量
- 返回类型
RRef
跳过连接¶
某些模型(如 ResNeXt)不是完全顺序的,并且在层之间具有跳过连接。天真地作为管道并行的一部分实施意味着我们需要通过多个 GPU 复制某些层的输出,直到最终到达跳过连接所在的层的 GPU。为了避免这种复制开销,我们在下面提供了 API,以在模型的不同层中隐藏和弹出张量。
- torch.distributed.pipeline.sync.skip.skippable.skippable(stash=(), pop=())[source]¶
定义一个装饰器来创建具有跳过连接的
nn.Module
。这些装饰模块称为“可跳过”。即使模块未被
Pipe
包装,此功能也能完美运行。每个跳过张量都由其名称管理。在操作跳过张量之前,可跳过模块必须通过 stash 和/或 pop 参数静态声明跳过张量的名称。具有预声明名称的跳过张量可以通过
yield stash(name, tensor)
存储,或通过tensor = yield pop(name)
弹出。下面是一个包含三层示例。一个名为“1to3”的跳过张量分别在第一层和最后一层存储和弹出
@skippable(stash=['1to3']) class Layer1(nn.Module): def forward(self, input): yield stash('1to3', input) return f1(input) class Layer2(nn.Module): def forward(self, input): return f2(input) @skippable(pop=['1to3']) class Layer3(nn.Module): def forward(self, input): skip_1to3 = yield pop('1to3') return f3(input) + skip_1to3 model = nn.Sequential(Layer1(), Layer2(), Layer3())
一个可跳过模块可以存储或弹出多个跳过张量
@skippable(stash=['alice', 'bob'], pop=['carol']) class StashStashPop(nn.Module): def forward(self, input): yield stash('alice', f_alice(input)) yield stash('bob', f_bob(input)) carol = yield pop('carol') return input + carol
每个跳过张量都必须与一对 stash 和 pop 关联。在包装模块时,
Pipe
会自动检查此限制。你还可以通过verify_skippables()
检查限制,而无需Pipe
。
- class torch.distributed.pipeline.sync.skip.skippable.stash(name, tensor)[source]¶
存储跳过张量的命令。
def forward(self, input): yield stash('name', input) return f(input)
- 参数
name (str) – 跳过张量的名称
input (torch.Tensor 或 None) – 传递到跳过连接的张量
- 类 torch.distributed.pipeline.sync.skip.skippable.pop(name)[源代码]¶
弹出跳过张量的命令。
def forward(self, input): skip = yield pop('name') return f(input) + skip
- 参数
name (str) – 跳过张量的名称
- 返回
由另一个层在相同名称下先前隐藏的跳过张量
- 返回类型
无
- torch.distributed.pipeline.sync.skip.skippable.verify_skippables(module)[源代码]¶
验证底层可跳过模块是否满足完整性。
每个跳过张量必须只有一对 stash 和 pop。如果存在一对或多对不匹配的,它将引发
TypeError
,其中包含详细消息。以下是一些失败案例。
verify_skippables()
将报告这些案例的失败# Layer1 stashes "1to3". # Layer3 pops "1to3". nn.Sequential(Layer1(), Layer2()) # └──── ? nn.Sequential(Layer2(), Layer3()) # ? ────┘ nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) # └───────────────────┘ ^^^^^^ nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) # ^^^^^^ └───────────────────┘
要对多个跳过张量使用相同的名称,它们必须通过不同的命名空间进行隔离。请参阅
isolate()
。- 引发
TypeError – 一对或多对 stash 和 pop 不匹配。
鸣谢¶
管道并行的实现基于 fairscale 的管道实现 和 torchgpipe。我们要感谢这两个团队为将管道并行引入 PyTorch 做出的贡献和指导。