流水线并行¶
注意
torch.distributed.pipelining
目前处于 alpha 状态,正在开发中。API 可能会更改。它从 PiPPy 项目迁移而来。
为什么选择流水线并行?¶
流水线并行是深度学习的基本并行方法之一。它允许对模型执行进行分区,以便多个微批次可以并发执行模型代码的不同部分。对于以下情况,流水线并行可能是一种有效的技术:
大规模训练
带宽受限的集群
大型模型推理
以上场景的共同点是,每个设备的计算量无法隐藏传统并行方法的通信开销,例如,FSDP 的权重全收集 (all-gather)。
torch.distributed.pipelining
是什么?¶
虽然流水线并行在扩展方面前景广阔,但它通常难以实现,因为它除了模型权重之外,还需要划分模型执行。执行划分通常需要对模型进行侵入式代码更改。复杂性的另一个方面来自于在分布式环境中调度微批次,并考虑数据流依赖性。
pipelining
包提供了一个工具包,可以自动完成上述操作,从而可以轻松地在通用模型上实现流水线并行。
它由两部分组成:一个拆分前端和一个分布式运行时。拆分前端接受您的模型代码,将其拆分为“模型分区”,并捕获数据流关系。分布式运行时在不同设备上并行执行流水线阶段,处理微批次拆分、调度、通信和梯度传播等。
总的来说,pipelining
包提供以下功能:
基于简单规范的模型代码拆分。
对流水线调度的丰富支持,包括 GPipe、1F1B、交错 1F1B 和循环 BFS,并为编写自定义调度提供基础设施。
对跨主机流水线并行的首要支持,因为这通常是 PP 的使用场景(在较慢的互连上)。
与其他 PyTorch 并行技术(如数据并行 (DDP, FSDP) 或张量并行)的可组合性。TorchTitan 项目演示了 Llama 模型上的“3D 并行”应用。
步骤 1:构建 PipelineStage
¶
在使用 PipelineSchedule
之前,我们需要创建 PipelineStage
对象,这些对象封装了在阶段中运行的模型部分。PipelineStage
负责分配通信缓冲区,并创建发送/接收操作以与其对等方通信。它管理中间缓冲区,例如尚未使用的前向传播输出,并提供实用程序来运行阶段模型的反向传播。
PipelineStage
需要知道阶段模型的输入和输出形状,以便它可以正确分配通信缓冲区。形状必须是静态的,例如,在运行时,形状不能在步骤之间更改。如果运行时形状与预期形状不匹配,则会引发 PipeliningShapeError
类。当与其他并行性组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage
知道运行时阶段模块输出的正确形状(和 dtype)。
用户可以直接构造 PipelineStage
实例,方法是传入一个 nn.Module
,表示应在该阶段上运行的模型部分。这可能需要更改原始模型代码。请参阅 选项 1:手动拆分模型 中的示例。
或者,拆分前端可以使用图分区自动将模型拆分为一系列 nn.Module
。此技术要求模型可使用 torch.Export
进行追踪。生成的 nn.Module
与其他并行技术的可组合性是实验性的,可能需要一些解决方法。如果用户不容易更改模型代码,则使用此前端可能更具吸引力。有关更多信息,请参阅 选项 2:自动拆分模型。
步骤 2:使用 PipelineSchedule
执行¶
我们现在可以将 PipelineStage
附加到流水线调度,并使用输入数据运行调度。这是一个 GPipe 示例
from torch.distributed.pipelining import ScheduleGPipe
# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)
# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)
# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
schedule.step(x)
else:
output = schedule.step()
请注意,上述代码需要为每个 worker 启动,因此我们使用启动器服务来启动多个进程
torchrun --nproc_per_node=2 example.py
拆分模型的选项¶
选项 1:手动拆分模型¶
要直接构造 PipelineStage
,用户负责提供单个 nn.Module
实例,该实例拥有相关的 nn.Parameters
和 nn.Buffers
,并定义一个 forward()
方法,该方法执行该阶段相关的操作。例如,Torchtitan 中定义的 Transformer 类的精简版本展示了构建易于分区的模型的模式。
class Transformer(nn.Module):
def __init__(self, model_args: ModelArgs):
super().__init__()
self.tok_embeddings = nn.Embedding(...)
# Using a ModuleDict lets us delete layers without affecting names,
# ensuring checkpoints will correctly save and load.
self.layers = torch.nn.ModuleDict()
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(...)
self.output = nn.Linear(...)
def forward(self, tokens: torch.Tensor):
# Handling layers being 'None' at runtime enables easy pipeline splitting
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
for layer in self.layers.values():
h = layer(h, self.freqs_cis)
h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
return output
以这种方式定义的模型可以很容易地按阶段配置,方法是首先初始化整个模型(使用 meta-device 以避免 OOM 错误),删除该阶段不需要的层,然后创建一个 PipelineStage 来封装模型。例如
with torch.device("meta"):
assert num_stages == 2, "This is a simple 2-stage example"
# we construct the entire model, then delete the parts we do not need for this stage
# in practice, this can be done using a helper function that automatically divides up layers across stages.
model = Transformer()
if stage_index == 0:
# prepare the first stage model
del model.layers["1"]
model.norm = None
model.output = None
elif stage_index == 1:
# prepare the second stage model
model.tok_embeddings = None
del model.layers["0"]
from torch.distributed.pipelining import PipelineStage
stage = PipelineStage(
model,
stage_index,
num_stages,
device,
input_args=example_input_microbatch,
)
PipelineStage
需要一个示例参数 input_args
,表示阶段的运行时输入,这将是一个微批次的输入数据。此参数通过阶段模块的 forward 方法传递,以确定通信所需的输入和输出形状。
当与其他数据或模型并行技术组合时,如果模型块的输出形状/dtype 将受到影响,则也可能需要 output_args
。
选项 2:自动拆分模型¶
如果您有一个完整的模型,并且不想花费时间将其修改为一系列“模型分区”,则 pipeline
API 可以提供帮助。这是一个简短的示例
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.emb = torch.nn.Embedding(10, 3)
self.layers = torch.nn.ModuleList(
Layer() for _ in range(2)
)
self.lm = LMHead()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.emb(x)
for layer in self.layers:
x = layer(x)
x = self.lm(x)
return x
如果我们打印模型,我们可以看到多个层次结构,这使得手动拆分很困难
Model(
(emb): Embedding(10, 3)
(layers): ModuleList(
(0-1): 2 x Layer(
(lin): Linear(in_features=3, out_features=3, bias=True)
)
)
(lm): LMHead(
(proj): Linear(in_features=3, out_features=3, bias=True)
)
)
让我们看看 pipeline
API 是如何工作的
from torch.distributed.pipelining import pipeline, SplitPoint
# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])
pipe = pipeline(
module=mod,
mb_args=(x,),
split_spec={
"layers.1": SplitPoint.BEGINNING,
}
)
pipeline
API 根据 split_spec
拆分您的模型,其中 SplitPoint.BEGINNING
表示在 forward
函数中执行某些子模块之前添加拆分点,类似地,SplitPoint.END
表示在执行之后添加拆分点。
如果我们 print(pipe)
,我们可以看到
GraphModule(
(submod_0): GraphModule(
(emb): InterpreterModule()
(layers): Module(
(0): InterpreterModule(
(lin): InterpreterModule()
)
)
)
(submod_1): GraphModule(
(layers): Module(
(1): InterpreterModule(
(lin): InterpreterModule()
)
)
(lm): InterpreterModule(
(proj): InterpreterModule()
)
)
)
def forward(self, x):
submod_0 = self.submod_0(x); x = None
submod_1 = self.submod_1(submod_0); submod_0 = None
return (submod_1,)
“模型分区”由子模块 (submod_0
, submod_1
) 表示,每个子模块都使用原始模型操作、权重和层次结构重建。此外,重建了一个“根级别”forward
函数,以捕获这些分区之间的数据流。此类数据流稍后将由流水线运行时以分布式方式重放。
Pipe
对象提供了一种检索“模型分区”的方法
stage_mod : nn.Module = pipe.get_stage_module(stage_idx)
返回的 stage_mod
是一个 nn.Module
,您可以使用它来创建优化器、保存或加载检查点,或应用其他并行性。
Pipe
还允许您在给定 ProcessGroup
的情况下在设备上创建分布式阶段运行时
stage = pipe.build_stage(stage_idx, device, group)
或者,如果您想在对 stage_mod
进行一些修改后稍后构建阶段运行时,则可以使用 build_stage
API 的函数式版本。例如
from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParallel
dp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)
注意
pipeline
前端使用追踪器 (torch.export
) 将您的模型捕获到单个图中。如果您的模型不是完全可图化的,您可以使用下面的手动前端。
Hugging Face 示例¶
在本软件包最初创建的 PiPPy 代码仓库中,我们保留了基于未修改的 Hugging Face 模型的示例。请参阅 examples/huggingface 目录。
示例包括:
技术深入探讨¶
pipeline
API 如何拆分模型?¶
首先,pipeline
API 通过追踪模型将我们的模型转换为有向无环图 (DAG)。它使用 torch.export
(PyTorch 2 全图捕获工具)追踪模型。
然后,它将一个阶段所需的操作和参数组合在一起,形成一个重建的子模块:submod_0
、submod_1
、...
与传统的子模块访问方法(如 Module.children()
)不同,pipeline
API 不仅切割模型的模块结构,还切割模型的 forward 函数。
这是必要的,因为像 Module.children()
这样的模型结构仅仅捕获 Module.__init__()
期间的信息,而不捕获关于 Module.forward()
的任何信息。换句话说,Module.children()
缺少关于流水线并行关键方面的以下信息:
子模块在
forward
中的执行顺序子模块之间的激活流
子模块之间是否存在任何函数式算子(例如,
relu
或add
操作不会被Module.children()
捕获)。
相反,pipeline
API 确保真正保留 forward
行为。它还捕获分区之间的激活流,帮助分布式运行时进行正确的发送/接收调用,而无需人工干预。
pipeline
API 的另一个灵活性是,拆分点可以位于模型层次结构中的任意级别。在拆分的分区中,与该分区相关的原始模型层次结构将免费重建。因此,指向子模块或参数的完全限定名称 (FQNs) 仍然有效,并且依赖于 FQNs 的服务(例如 FSDP、TP 或检查点)仍然可以在代码几乎零更改的情况下与您的分区模块一起运行。
实现您自己的调度¶
您可以通过扩展以下两个类之一来实现您自己的流水线调度:
PipelineScheduleSingle
PipelineScheduleMulti
PipelineScheduleSingle
适用于每个 rank 仅分配一个阶段的调度。PipelineScheduleMulti
适用于每个 rank 分配多个阶段的调度。
例如,ScheduleGPipe
和 Schedule1F1B
是 PipelineScheduleSingle
的子类。而 ScheduleInterleaved1F1B
、ScheduleLoopedBFS
、ScheduleInterleavedZeroBubble
和 ScheduleZBVZeroBubble
是 PipelineScheduleMulti
的子类。
日志记录¶
您可以使用来自 [torch._logging](https://pytorch.ac.cn/docs/main/logging.html#module-torch._logging) 的 TORCH_LOGS 环境变量打开额外的日志记录
TORCH_LOGS=+pp 将显示 logging.DEBUG 消息和所有更高级别的消息。
TORCH_LOGS=pp 将显示 logging.INFO 消息和更高级别的消息。
TORCH_LOGS=-pp 将显示 logging.WARNING 消息和更高级别的消息。
API 参考¶
模型拆分 API¶
以下 API 集将您的模型转换为流水线表示。
- torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)[source][source]¶
根据规范拆分模块。
有关更多详细信息,请参阅 Pipe。
- 参数:
module (Module) – 要拆分的模块。
mb_kwargs (Optional[Dict[str, Any]]) – 微批次形式的示例关键字输入。(默认值:None)
split_spec (Optional[Dict[str, SplitPoint]]) – 使用子模块名称作为拆分标记的字典。(默认值:None)
split_policy (Optional[Callable[[GraphModule], GraphModule]]) – 用于拆分模块的策略。(默认值:None)
- 返回类型:
类 Pipe 的流水线表示。
微批次实用程序¶
- class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)[source][source]¶
用于指定输入分块的类
Pipeline Stages¶
- class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args=None, output_args=None, group=None, dw_builder=None)[source][source]¶
表示流水线并行设置中的流水线阶段的类。
PipelineStage 假设模型的顺序分区,即模型被分割成多个块,其中一个块的输出馈送到下一个块的输入,没有跳跃连接。
PipelineStage 通过以线性顺序将 stage0 的输出传播到 stage1 等等,自动执行运行时形状/dtype 推断。 要绕过形状推断,请将 input_args 和 output_args 传递给每个 PipelineStage 实例。
- 参数:
submodule (nn.Module) – 此阶段包装的 PyTorch 模块。
stage_index (int) – 此阶段的 ID。
num_stages (int) – 阶段总数。
device (torch.device) – 此阶段所在的设备。
input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输入参数。
output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输出参数。
group (dist.ProcessGroup, optional) – 用于分布式训练的进程组。 如果为 None,则为默认组。
dw_builder (Optional[Callable[[], Callable[[...], None]]]) – TODO 清理注释
- torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)[source][source]¶
给定要由此阶段包装的 stage_module 和流水线信息,创建一个流水线阶段。
- 参数:
stage_module (torch.nn.Module) – 要由此阶段包装的模块
stage_index (int) – 此阶段在流水线中的索引
pipe_info (PipeInfo) – 关于流水线的信息,可以通过 pipe.info() 检索
device (torch.device) – 此阶段要使用的设备
group (Optional[dist.ProcessGroup]) – 此阶段要使用的进程组
- Returns
一个可以与 PipelineSchedules 一起运行的流水线阶段。
- 返回类型:
_PipelineStage
Pipeline Schedules¶
- class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source][source]¶
GPipe 调度。 将以填充-耗尽的方式遍历所有微批次。
- class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source][source]¶
1F1B 调度。 将在稳态下对微批次执行一次前向和一次反向传播。
- class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source][source]¶
交错 1F1B 调度。 有关详细信息,请参阅 https://arxiv.org/pdf/2104.04473。 将在稳态下对微批次执行一次前向和一次反向传播,并支持每个 rank 多个阶段。 当微批次准备好用于多个本地阶段时,交错 1F1B 优先考虑较早的微批次(也称为“深度优先”)。
此调度程序与原始论文中的调度程序非常相似。 它的不同之处在于放宽了 num_microbatch % pp_size == 0 的要求。 使用 flex_pp 调度程序,我们将有 num_rounds = max(1, n_microbatches // pp_group_size),并且只要 n_microbatches % num_rounds 为 0 即可工作。 举几个例子,支持
pp_group_size = 4, n_microbatches = 10。我们将有 num_rounds = 2,并且 n_microbatches % 2 为 0。
pp_group_size = 4, n_microbatches = 3。我们将有 num_rounds = 1,并且 n_microbatches % 1 为 0。
- class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None)[source][source]¶
广度优先流水线并行。 有关详细信息,请参阅 https://arxiv.org/abs/2211.05953。 与交错 1F1B 类似,循环 BFS 支持每个 rank 多个阶段。 不同之处在于,当微批次准备好用于多个本地阶段时,循环 BFS 将优先考虑较早的阶段,一次运行所有可用的微批次。
- class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source][source]¶
交错零气泡调度。 有关详细信息,请参阅 https://arxiv.org/pdf/2401.10241。 将在稳态下对微批次的输入执行一次前向和一次反向传播,并支持每个 rank 多个阶段。 使用权重的反向传播来填充流水线气泡。
特别是,这是在论文中实现 ZB1P 调度。
- class torch.distributed.pipelining.schedules.ScheduleZBVZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, stage_index_to_group_rank=None)[source][source]¶
零气泡调度 (ZBV 变体)。 有关详细信息,请参阅 https://arxiv.org/pdf/2401.10241 第 6 节。
此调度程序每个 rank 恰好需要两个阶段。
此调度程序将在稳态下对微批次的输入执行一次前向和一次反向传播,并支持每个 rank 多个阶段。 使用相对于权重的反向传播来填充流水线气泡。
仅当时间前向传播 == 时间反向传播输入 == 时间反向传播权重时,此 ZB-V 调度才具有“零气泡”属性。 实际上,这对于真实模型不太可能成立,因此可以为不等/不平衡时间实现贪婪调度程序。
- class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source][source]¶
单阶段调度的基类。 实现 step 方法。 派生类应实现 _step_microbatches。