管道并行¶
注意
torch.distributed.pipelining
处于 alpha 状态,目前正在开发中。API 可能会有所变化。它已从 PiPPy 项目迁移。
为什么使用管道并行?¶
管道并行是深度学习的一种基本并行方式。它允许将模型的执行进行分区,以便多个微批次可以同时执行模型代码的不同部分。管道并行可以成为一种有效的技术,用于
大规模训练
带宽受限的集群
大型模型推理
上述场景有一个共同点,即每个设备的计算无法隐藏传统并行的通信,例如 FSDP 的权重全收集。
torch.distributed.pipelining
是什么?¶
虽然管道并行在扩展方面很有前景,但它通常难以实现,因为它需要除了模型权重之外,还要对模型的执行进行分区。执行分区通常需要对模型代码进行侵入式更改。复杂性的另一个方面来自于在考虑数据流依赖关系的情况下,在分布式环境中调度微批次。
pipelining
包提供了一个工具包,可以自动执行这些操作,从而可以在通用模型上轻松实现管道并行。
它由两部分组成:一个拆分前端和一个分布式运行时。拆分前端将您的模型代码按原样获取,将其拆分为“模型分区”,并捕获数据流关系。分布式运行时在不同的设备上并行执行管道阶段,处理微批次拆分、调度、通信和梯度传播等。
总的来说,pipelining
包提供了以下功能
基于简单规范拆分模型代码。
丰富地支持管道调度,包括 GPipe、1F1B、交织 1F1B 和循环 BFS,并提供编写自定义调度的基础结构。
对跨主机管道并行提供一流的支持,因为这是 PP 通常使用的地方(通过较慢的互连)。
与其他 PyTorch 并行技术(如数据并行 (DDP、FSDP) 或张量并行)的可组合性。该 TorchTitan 项目展示了 Llama 模型上的“3D 并行”应用程序。
步骤 1:构建 PipelineStage
¶
在我们使用 PipelineSchedule
之前,我们需要创建 PipelineStage
对象,这些对象包装了在该阶段运行的模型部分。 PipelineStage
负责分配通信缓冲区,并创建发送/接收操作以与其对等方进行通信。它管理中间缓冲区(例如,尚未被消耗的正向输出的缓冲区),并提供一个用于运行阶段模型的向后传播的实用程序。
一个 PipelineStage
需要知道阶段模型的输入和输出形状,以便它可以正确地分配通信缓冲区。形状必须是静态的,例如,在运行时,形状不能在步骤之间改变。如果运行时形状与预期形状不匹配,则会引发一个 PipeliningShapeError
类。在与其他并行机制组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage
能够知道阶段模块在运行时的输出的正确形状(和数据类型)。
用户可以通过直接传入一个表示应在阶段运行的模型部分的 nn.Module
来构建 PipelineStage
实例。这可能需要更改原始模型代码。请参阅 选项 1:手动拆分模型 中的示例。
或者,分割前端可以使用图分区将您的模型自动分成一系列的nn.Module
。 此技术要求模型可以使用torch.Export
进行追踪。 结果nn.Module
与其他并行技术的可组合性处于实验阶段,可能需要一些变通方法。 如果用户无法轻松更改模型代码,则使用此前端可能更具吸引力。 有关更多信息,请参阅选项 2:自动分割模型。
步骤 2:使用PipelineSchedule
执行¶
现在,我们可以将PipelineStage
附加到管道计划,并使用输入数据运行计划。 以下是一个 GPipe 示例
from torch.distributed.pipelining import ScheduleGPipe
# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)
# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)
# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
schedule.step(x)
else:
output = schedule.step()
请注意,以上代码需要为每个工作程序启动,因此我们使用启动器服务来启动多个进程
torchrun --nproc_per_node=2 example.py
分割模型的选项¶
选项 1:手动分割模型¶
要直接构建PipelineStage
,用户需要提供一个单个nn.Module
实例,该实例拥有相关的nn.Parameters
和nn.Buffers
,并定义一个forward()
方法来执行与该阶段相关的操作。 例如,Torchtitan 中定义的 Transformer 类的一个简化版本展示了构建一个易于分区的模型的模式。
class Transformer(nn.Module):
def __init__(self, model_args: ModelArgs):
super().__init__()
self.tok_embeddings = nn.Embedding(...)
# Using a ModuleDict lets us delete layers without affecting names,
# ensuring checkpoints will correctly save and load.
self.layers = torch.nn.ModuleDict()
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(...)
self.output = nn.Linear(...)
def forward(self, tokens: torch.Tensor):
# Handling layers being 'None' at runtime enables easy pipeline splitting
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
for layer in self.layers.values():
h = layer(h, self.freqs_cis)
h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
return output
通过首先初始化整个模型(使用元设备来避免 OOM 错误),删除该阶段不需要的层,然后创建一个包装模型的 PipelineStage,可以轻松配置以这种方式定义的模型。 例如
with torch.device("meta"):
assert num_stages == 2, "This is a simple 2-stage example"
# we construct the entire model, then delete the parts we do not need for this stage
# in practice, this can be done using a helper function that automatically divides up layers across stages.
model = Transformer()
if stage_index == 0:
# prepare the first stage model
del model.layers["1"]
model.norm = None
model.output = None
elif stage_index == 1:
# prepare the second stage model
model.tok_embeddings = None
del model.layers["0"]
from torch.distributed.pipelining import PipelineStage
stage = PipelineStage(
model,
stage_index,
num_stages,
device,
input_args=example_input_microbatch,
)
PipelineStage
需要一个示例参数input_args
,它表示该阶段的运行时输入,该输入将是一个微批次的输入数据。 此参数将通过阶段模块的 forward 方法传递,以确定通信所需的输入和输出形状。
当与其他数据或模型并行技术组合时,如果模型块的输出形状/数据类型将受到影响,则可能还需要output_args
。
选项 2:自动分割模型¶
如果您有一个完整的模型,并且不想花费时间将其修改为一系列“模型分区”,那么pipeline
API 可以提供帮助。 以下是一个简短的示例
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.emb = torch.nn.Embedding(10, 3)
self.layers = torch.nn.ModuleList(
Layer() for _ in range(2)
)
self.lm = LMHead()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.emb(x)
for layer in self.layers:
x = layer(x)
x = self.lm(x)
return x
如果我们打印模型,我们可以看到多个层次结构,这使得手动分割变得很困难
Model(
(emb): Embedding(10, 3)
(layers): ModuleList(
(0-1): 2 x Layer(
(lin): Linear(in_features=3, out_features=3, bias=True)
)
)
(lm): LMHead(
(proj): Linear(in_features=3, out_features=3, bias=True)
)
)
让我们看看pipeline
API 是如何工作的
from torch.distributed.pipelining import pipeline, SplitPoint
# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])
pipe = pipeline(
module=mod,
mb_args=(x,),
split_spec={
"layers.1": SplitPoint.BEGINNING,
}
)
pipeline
API 会根据一个split_spec
分割您的模型,其中SplitPoint.BEGINNING
表示在forward
函数中执行某个子模块之前添加一个分割点,类似地,SplitPoint.END
表示在执行之后添加一个分割点。
如果我们print(pipe)
,我们可以看到
GraphModule(
(submod_0): GraphModule(
(emb): InterpreterModule()
(layers): Module(
(0): InterpreterModule(
(lin): InterpreterModule()
)
)
)
(submod_1): GraphModule(
(layers): Module(
(1): InterpreterModule(
(lin): InterpreterModule()
)
)
(lm): InterpreterModule(
(proj): InterpreterModule()
)
)
)
def forward(self, x):
submod_0 = self.submod_0(x); x = None
submod_1 = self.submod_1(submod_0); submod_0 = None
return (submod_1,)
“模型分区”由子模块(submod_0
,submod_1
)表示,每个子模块都使用原始模型操作、权重和层次结构重建。 此外,一个“根级”forward
函数被重建以捕获这些分区之间的数据流。 这种数据流将在以后由管道运行时以分布式方式重放。
Pipe
对象提供了一种检索“模型分区”的方法
stage_mod : nn.Module = pipe.get_stage_module(stage_idx)
返回的stage_mod
是一个nn.Module
,您可以用它来创建优化器、保存或加载检查点,或应用其他并行性。
Pipe
还允许您在给定一个ProcessGroup
的情况下在设备上创建一个分布式阶段运行时
stage = pipe.build_stage(stage_idx, device, group)
或者,如果您想在对stage_mod
进行一些修改之后再构建阶段运行时,可以使用build_stage
API 的函数版本。 例如
from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParallel
dp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)
注意
pipeline
前端使用跟踪器(torch.export
)将您的模型捕获到一个单一图中。 如果您的模型不是全图可追踪的,您可以使用以下手动前端。
Hugging Face 示例¶
在此包最初创建的PiPPy 存储库中,我们保留了基于未修改的 Hugging Face 模型的示例。 请参阅examples/huggingface 目录。
示例包括
技术深入分析¶
pipeline
API 如何分割模型?¶
首先,pipeline
API 通过跟踪模型将我们的模型转换为有向无环图 (DAG)。 它使用torch.export
(一个 PyTorch 2 全图捕获工具)来跟踪模型。
然后,它将一个阶段需要的操作和参数分组到一个重建的子模块中:submod_0
,submod_1
,…
与传统的子模块访问方法(如Module.children()
)不同,pipeline
API 不仅切断模型的模块结构,还切断模型的forward函数。
这是必要的,因为像Module.children()
这样的模型结构只捕获Module.__init__()
期间的信息,而不会捕获任何关于Module.forward()
的信息。 换句话说,Module.children()
缺少以下对管道化至关重要的方面的信息
forward
中子模块的执行顺序子模块之间的激活流
子模块之间是否存在任何函数运算符(例如,
relu
或add
操作将不会被Module.children()
捕获)。
相反,pipeline
API 确保真正保留了forward
行为。 它还捕获了分区之间的激活流,帮助分布式运行时在没有人为干预的情况下进行正确的发送/接收调用。
pipeline
API 的另一个灵活性是分割点可以位于模型层次结构中的任意级别。 在分割分区中,与该分区相关的原始模型层次结构将被重建,您无需为此付出任何代价。 因此,指向子模块或参数的完全限定名称 (FQN) 仍然有效,依赖 FQN 的服务(例如 FSDP、TP 或检查点)仍然可以使用您分割的模块运行,代码更改几乎为零。
实现您自己的计划¶
您可以通过扩展以下两个类之一来实现您自己的管道计划
PipelineScheduleSingle
PipelineScheduleMulti
PipelineScheduleSingle
用于将仅一个阶段分配给每个等级的计划。 PipelineScheduleMulti
用于将多个阶段分配给每个等级的计划。
例如,ScheduleGPipe
和Schedule1F1B
是PipelineScheduleSingle
的子类。 而ScheduleFlexibleInterleaved1F1B
、ScheduleInterleaved1F1B
和ScheduleLoopedBFS
是PipelineScheduleMulti
的子类。
日志记录¶
您可以使用来自 [torch._logging](https://pytorch.ac.cn/docs/main/logging.html#module-torch._logging) 的TORCH_LOGS环境变量来打开其他日志记录
TORCH_LOGS=+pp 将显示logging.DEBUG消息及其以上所有级别。
TORCH_LOGS=pp 将显示logging.INFO消息及其以上。
TORCH_LOGS=-pp 将显示logging.WARNING消息及其以上。
API 参考¶
模型分割 API¶
以下 API 集合将您的模型转换为管道表示形式。
- torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)[source]¶
根据规范分割模块。
有关更多详细信息,请参阅Pipe。
- 参数
module (Module) – 要拆分的模块。
mb_kwargs (Optional[Dict[str, Any]]) – 微批形式的示例关键字输入。 (默认值: None)
split_spec (Optional[Dict[str, SplitPoint]]) – 使用子模块名称作为分割标记的字典。 (默认值: None)
split_policy (Optional[Callable[[GraphModule], GraphModule]]) – 用于拆分模块的策略。 (默认值: None)
- 返回值类型
类Pipe 的管道表示。
微批实用程序¶
管道阶段¶
- class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args, output_args=None, group=None, dw_builder=None)[source]¶
在管道并行设置中表示管道阶段的类。这个类是通过提供示例输入(以及可选输出)手动创建的,而不是从 pipeline() 输出的 PipelineStage 类。这个类扩展了 _PipelineStageBase 类,并且可以类似地用于 PipelineScheule 中。
- 参数
submodule (nn.Module) – 此阶段包装的 PyTorch 模块。
stage_index (int) – 此阶段的 ID。
num_stages (int) – 阶段总数。
device (torch.device) – 此阶段所在的设备。
input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输入参数。
output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输出参数。
group (dist.ProcessGroup, optional) – 分布式训练的进程组。如果为 None,则为默认组。
dw_builder (Optional[Callable[[], Callable[[...], None]]]) – TODO 清理注释
- torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)[source]¶
给定一个要由此阶段包装的 stage_module 和管道信息,创建一个管道阶段。
- 参数
stage_module (torch.nn.Module) – 要由此阶段包装的模块
stage_index (int) – 此阶段在管道中的索引
pipe_info (PipeInfo) – 有关管道的的信息,可以通过 pipe.info() 检索
device (torch.device) – 此阶段要使用的设备
group (Optional[dist.ProcessGroup]) – 此阶段要使用的进程组
- 返回值
一个可以与 PipelineSchedules 一起运行的流水线阶段。
- 返回值类型
_PipelineStage
流水线调度¶
- class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source]¶
GPipe 调度。 将以填充-清空的方式遍历所有微批次。
- class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source]¶
1F1B 调度。 将在稳态下对微批次执行一次前向和一次反向。
- class torch.distributed.pipelining.schedules.ScheduleFlexibleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, enable_zero_bubble=False)[source]¶
灵活交织的 1F1B 调度。
该调度与交织的 1F1B 调度非常相似。 它的不同之处在于放松了 num_microbatch % pp_size == 0 的要求。 使用 flex_pp 调度,我们将有 num_rounds = max(1, n_microbatches // pp_group_size),只要 n_microbatches % num_rounds 为 0,它就起作用。 举几个例子,支持
pp_group_size = 4, n_microbatches = 10。 我们将有 num_rounds = 2 且 n_microbatches % 2 为 0。
pp_group_size = 4, n_microbatches = 3。 我们将有 num_rounds = 1 且 n_microbatches % 1 为 0。
当 enable_zero_bubble 为 True 时,我们将使用 https://openreview.net/pdf?id=tuzTN0eIO5 中的 ZB1P 调度。
- class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source]¶
交织的 1F1B 调度。 有关详细信息,请参阅 https://arxiv.org/pdf/2104.04473。 将在稳态下对微批次执行一次前向和一次反向,并支持每个等级上的多个阶段。 当微批次准备好用于多个本地阶段时,交织的 1F1B 会优先考虑较早的微批次(也称为“深度优先”)。
- class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None)[source]¶
广度优先流水线并行。 有关详细信息,请参阅 https://arxiv.org/abs/2211.05953。 与交织的 1F1B 类似,循环 BFS 支持每个等级上的多个阶段。 不同之处在于,当微批次准备好用于多个本地阶段时,循环 BFS 将优先考虑较早的阶段,一次运行所有可用的微批次。
- class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source]¶
交织零气泡调度。 有关详细信息,请参阅 https://arxiv.org/pdf/2401.10241。 将在稳态下对微批次的输入执行一次前向和一次反向,并支持每个等级上的多个阶段。 使用权重的反向来填充流水线气泡。
- class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[source]¶
单阶段调度的基类。 实现 step 方法。 派生类应实现 _step_microbatches。