快捷方式

流水线并行

注意

torch.distributed.pipelining 目前处于 alpha 阶段,仍在开发中。API 可能会发生变化。它从 PiPPy 项目迁移而来。

为何选择流水线并行?

流水线并行 (Pipeline Parallelism) 是深度学习的**基本**并行化方法之一。它允许模型**执行**被分割,以便多个**微批次**可以并发执行模型代码的不同部分。流水线并行是一种有效的技术,适用于:

  • 大规模训练

  • 带宽受限的集群

  • 大型模型推理

上述场景有一个共同点:每个设备的计算无法隐藏传统并行化的通信开销,例如 FSDP 的权重 all-gather。

什么是 torch.distributed.pipelining

尽管流水线化在扩展方面很有前景,但实现起来通常很困难,因为它不仅需要分割模型权重,还需要**分割模型的执行**。分割执行通常需要对您的模型进行侵入性修改。复杂性的另一个方面在于**在分布式环境中调度微批次**,同时**考虑数据流依赖性**。

pipelining 包提供了一个工具包,可以**自动**完成上述任务,从而可以在**通用**模型上轻松实现流水线并行化。

它由两部分组成:一个**分割前端**和一个**分布式运行时**。分割前端接收您原封不动的模型代码,将其分割成“模型分区”,并捕获数据流关系。分布式运行时在不同设备上并行执行流水线阶段,处理微批次分割、调度、通信和梯度传播等任务。

总的来说,pipelining 包提供了以下特性:

  • 基于简单规范分割模型代码。

  • 丰富支持流水线调度,包括 GPipe、1F1B、交错 1F1B (Interleaved 1F1B) 和循环 BFS (Looped BFS),并提供基础设施用于编写自定义调度。

  • 一流地支持跨主机流水线并行化,因为 PP 通常在这种场景下使用(通过较慢的互连)。

  • 可与其他 PyTorch 并行化技术(如数据并行 (DDP, FSDP) 或张量并行)组合使用。TorchTitan 项目展示了 Llama 模型上的“3D 并行”应用。

步骤 1: 构建 PipelineStage

在使用 PipelineSchedule 之前,我们需要创建 PipelineStage 对象,这些对象封装了在该阶段运行的模型部分。PipelineStage 负责分配通信缓冲区并创建发送/接收操作来与其对等方通信。它管理中间缓冲区,例如尚未被消耗的前向传播输出,并提供一个工具来运行阶段模型的反向传播。

PipelineStage 需要知道阶段模型的输入和输出形状,以便正确分配通信缓冲区。形状必须是静态的,例如,在运行时形状不能一步一步地改变。如果运行时形状与预期形状不匹配,将引发 PipeliningShapeError 类错误。与其他并行化技术组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage 在运行时知道阶段模块输出的正确形状(和 dtype)。

用户可以通过传入一个 nn.Module 实例直接构建 PipelineStage 实例,该实例代表应在该阶段运行的模型部分。这可能需要修改原始模型代码。请参见Option 1: 手动分割模型中的示例。

或者,分割前端可以使用图分区自动将模型分割成一系列 nn.Module。这种技术要求模型可以使用 torch.Export 进行跟踪(traceable)。由此生成的 nn.Module 与其他并行化技术的组合性是实验性的,可能需要一些变通方法。如果用户无法轻松更改模型代码,使用此前端可能更具吸引力。有关更多信息,请参见Option 2: 自动分割模型

步骤 2: 使用 PipelineSchedule 执行

现在我们可以将 PipelineStage 连接到流水线调度器,并使用输入数据运行调度器。以下是一个 GPipe 示例:

from torch.distributed.pipelining import ScheduleGPipe

# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)

# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
    schedule.step(x)
else:
    output = schedule.step()

请注意,上述代码需要在每个工作进程中启动,因此我们使用启动器服务来启动多个进程。

torchrun --nproc_per_node=2 example.py

分割模型的选项

Option 1: 手动分割模型

要直接构建 PipelineStage,用户需要负责提供一个单独的 nn.Module 实例,该实例拥有相关的 nn.Parametersnn.Buffers,并定义一个 forward() 方法来执行与该阶段相关的操作。例如,Torchtitan 中定义的 Transformer 类的一个精简版本展示了一种构建易于分区模型的方式。

class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(...)

        # Using a ModuleDict lets us delete layers without affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(...)

        self.output = nn.Linear(...)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h).float() if self.output else h
        return output

以这种方式定义的模型可以很容易地按阶段配置:首先初始化整个模型(使用 meta-device 避免 OOM 错误),删除该阶段不需要的层,然后创建一个封装该模型的 PipelineStage。例如:

with torch.device("meta"):
    assert num_stages == 2, "This is a simple 2-stage example"

    # we construct the entire model, then delete the parts we do not need for this stage
    # in practice, this can be done using a helper function that automatically divides up layers across stages.
    model = Transformer()

    if stage_index == 0:
        # prepare the first stage model
        del model.layers["1"]
        model.norm = None
        model.output = None

    elif stage_index == 1:
        # prepare the second stage model
        model.tok_embeddings = None
        del model.layers["0"]

    from torch.distributed.pipelining import PipelineStage
    stage = PipelineStage(
        model,
        stage_index,
        num_stages,
        device,
    )

与其他数据或模型并行化技术组合时,如果模型块的输出形状/dtype 将受到影响,可能还需要 output_args

Option 2: 自动分割模型

如果您有一个完整的模型,并且不想花费时间将其修改成一系列“模型分区”,那么 pipeline API 可以提供帮助。以下是一个简要示例:

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(10, 3)
        self.layers = torch.nn.ModuleList(
            Layer() for _ in range(2)
        )
        self.lm = LMHead()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.emb(x)
        for layer in self.layers:
            x = layer(x)
        x = self.lm(x)
        return x

如果我们打印模型,可以看到多层层次结构,这使得手动分割变得困难

Model(
  (emb): Embedding(10, 3)
  (layers): ModuleList(
    (0-1): 2 x Layer(
      (lin): Linear(in_features=3, out_features=3, bias=True)
    )
  )
  (lm): LMHead(
    (proj): Linear(in_features=3, out_features=3, bias=True)
  )
)

让我们看看 pipeline API 是如何工作的

from torch.distributed.pipelining import pipeline, SplitPoint

# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])

pipe = pipeline(
    module=mod,
    mb_args=(x,),
    split_spec={
        "layers.1": SplitPoint.BEGINNING,
    }
)

pipeline API 根据 split_spec 分割您的模型,其中 SplitPoint.BEGINNING 表示在 forward 函数中某个子模块执行*之前*添加分割点,类似地,SplitPoint.END 表示在*之后*添加分割点。

如果我们 print(pipe),可以看到

GraphModule(
  (submod_0): GraphModule(
    (emb): InterpreterModule()
    (layers): Module(
      (0): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
  )
  (submod_1): GraphModule(
    (layers): Module(
      (1): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
    (lm): InterpreterModule(
      (proj): InterpreterModule()
    )
  )
)

def forward(self, x):
    submod_0 = self.submod_0(x);  x = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    return (submod_1,)

“模型分区”由子模块(submod_0, submod_1)表示,每个子模块都使用原始模型的运算、权重和层次结构重建。此外,还重建了一个“根级别”的 forward 函数,以捕获这些分区之间的数据流。这种数据流稍后将由流水线运行时以分布式方式重放。

Pipe 对象提供了一个方法来检索“模型分区”:

stage_mod : nn.Module = pipe.get_stage_module(stage_idx)

返回的 stage_mod 是一个 nn.Module,您可以使用它来创建优化器、保存或加载检查点,或应用其他并行化技术。

Pipe 还允许您在给定 ProcessGroup 的设备上创建一个分布式阶段运行时:

stage = pipe.build_stage(stage_idx, device, group)

或者,如果您想在修改 stage_mod 后再构建阶段运行时,可以使用 build_stage API 的函数版本。例如:

from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParallel

dp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)

注意

pipeline 前端使用一个跟踪器(torch.export)将您的模型捕获到一个单一图中。如果您的模型不能进行完整图捕获(full-graph’able),您可以使用下面的手动前端。

Hugging Face 示例

该包最初创建于 PiPPy 仓库,我们在其中保留了基于未经修改的 Hugging Face 模型的示例。请参见 examples/huggingface 目录。

示例包括:

技术深入

pipeline API 如何分割模型?

首先,pipeline API 通过跟踪模型将其转化为有向无环图 (DAG)。它使用 torch.export(一个 PyTorch 2 完整图捕获工具)来跟踪模型。

然后,它将一个阶段所需的操作和参数分组到一个重建的子模块中:submod_0, submod_1, …

Module.children() 等传统的子模块访问方法不同,pipeline API 不仅会剪切您模型的模块结构,还会剪切模型的 forward 函数。

这是必要的,因为 Module.children() 等模型结构仅捕获 Module.__init__() 期间的信息,而不捕获有关 Module.forward() 的任何信息。换句话说,Module.children() 缺乏对流水线化至关重要的以下方面信息:

  • 子模块在 forward 中的执行顺序

  • 子模块之间的激活流

  • 子模块之间是否存在函数运算符(例如,Module.children() 不会捕获 reluadd 运算)。

相反,pipeline API 确保 forward 行为被真正保留。它还捕获分区之间的激活流,帮助分布式运行时无需人工干预即可进行正确的发送/接收调用。

pipeline API 的另一个灵活性在于,分割点可以在您的模型层次结构的任意级别上。在分割的分区中,与该分区相关的原始模型层次结构将被重建,对您而言没有额外开销。因此,指向子模块或参数的完全限定名称 (FQN) 仍然有效,并且依赖于 FQN 的服务(例如 FSDP、TP 或检查点)仍然可以与您分区后的模块一起运行,几乎无需修改代码。

实现您自己的调度器

您可以通过扩展以下两个类中的一个来实现您自己的流水线调度器:

  • PipelineScheduleSingle

  • PipelineScheduleMulti

PipelineScheduleSingle 用于为每个 rank 仅分配一个阶段的调度器。PipelineScheduleMulti 用于为每个 rank 分配多个阶段的调度器。

例如,ScheduleGPipeSchedule1F1BPipelineScheduleSingle 的子类。而 ScheduleInterleaved1F1BScheduleLoopedBFSScheduleInterleavedZeroBubbleScheduleZBVZeroBubblePipelineScheduleMulti 的子类。

日志记录

您可以使用 torch._logging 中的 TORCH_LOGS 环境变量来开启额外的日志记录:

  • TORCH_LOGS=+pp 将显示 logging.DEBUG 级别及以上的所有消息。

  • TORCH_LOGS=pp 将显示 logging.INFO 级别及以上的所有消息。

  • TORCH_LOGS=-pp 将显示 logging.WARNING 级别及以上的所有消息。

API 参考

模型分割 API

以下一系列 API 将您的模型转换为流水线表示形式。

class torch.distributed.pipelining.SplitPoint(value)[源代码][源代码]

枚举,表示子模块执行中可以发生分割的点。 :ivar BEGINNING: 表示在 forward 函数中某个子模块执行*之前*添加分割点。 :ivar END: 表示在 forward 函数中某个子模块执行*之后*添加分割点。

torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)[源代码][源代码]

根据规范分割模块。

有关更多详细信息,请参见 Pipe

参数
返回类型

Pipe 的流水线表示。

class torch.distributed.pipelining.Pipe(split_gm, num_stages, has_loss_and_backward, loss_spec)[源代码][源代码]
torch.distributed.pipelining.pipe_split()[源代码][源代码]

pipe_split 是一个特殊运算符,用于标记模块中阶段之间的边界。它用于将模块分割成阶段。如果您的标注模块在 eagerly 模式下运行,它是一个无操作 (no-op)。

示例

>>> def forward(self, x):
>>>     x = torch.mm(x, self.mm_param)
>>>     x = torch.relu(x)
>>>     pipe_split()
>>>     x = self.lin(x)
>>>     return x

以上示例将被分割成两个阶段。

微批次实用工具

class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)[源代码][源代码]

用于指定输入分块的类

torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)[源代码][源代码]

给定一系列 args 和 kwargs,根据它们各自的分块规范将它们分割成多个块。

参数
返回

分片 args 列表 kwargs_split:分片 kwargs 列表

返回类型

args_split

torch.distributed.pipelining.microbatch.merge_chunks(chunks, chunk_spec)[source][source]

给定一个块列表,根据块规范将它们合并为单个值。

参数
  • chunks (list[Any]) – 块列表

  • chunk_spec – 这些块的分块规范

返回

合并后的值

返回类型

流水线阶段

class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args=None, output_args=None, group=None, dw_builder=None)[source][source]

一个类,表示流水线并行设置中的一个流水线阶段。

PipelineStage 假设模型是顺序划分的,即模型被分割成块,其中一个块的输出作为下一个块的输入,没有跳跃连接。

PipelineStage 通过按线性顺序将 stage0 的输出传播到 stage1 等,自动执行运行时 shape/dtype 推断。要绕过 shape 推断,请将 input_argsoutput_args 传递给每个 PipelineStage 实例。

参数
  • submodule (nn.Module) – 此阶段封装的 PyTorch 模块。

  • stage_index (int) – 此阶段的 ID。

  • num_stages (int) – 阶段总数。

  • device (torch.device) – 此阶段所在的设备。

  • input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输入参数。

  • output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional) – 子模块的输出参数。

  • group (dist.ProcessGroup, optional) – 分布式训练的进程组。如果为 None,则使用默认组。

  • dw_builder (Optional[Callable[[], Callable[..., None]]) – 如果提供,dw_builder 将构建一个新的 dw_runner 函数,该函数用于 F, I, W (前向, 输入, 权重) 零气泡调度中的 W 操作(输入权重)。

torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)[source][source]

根据要由此阶段封装的 stage_module 和流水线信息创建一个流水线阶段。

参数
  • stage_module (torch.nn.Module) – 要由此阶段封装的模块

  • stage_index (int) – 此阶段在流水线中的索引

  • pipe_info (PipeInfo) – 关于流水线的信息,可以通过 pipe.info() 获取

  • device (torch.device) – 此阶段要使用的设备

  • group (Optional[dist.ProcessGroup]) – 此阶段要使用的进程组

返回

一个可以与 PipelineSchedules 一起运行的流水线阶段。

返回类型

_PipelineStage

流水线调度器

class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source][source]

GPipe 调度器。将以填充-排空方式处理所有微批次。

class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source][source]

1F1B 调度器。将在稳态下对微批次执行一次前向传播和一次反向传播。

class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source][source]

交错式 1F1B 调度器。详情请参阅 https://arxiv.org/pdf/2104.04473。将在稳态下对微批次执行一次前向传播和一次反向传播,并支持每个 rank 多个阶段。当多个本地阶段的微批次准备就绪时,交错式 1F1B 优先处理较早的微批次(也称为“深度优先”)。

这个调度器与原始论文基本相似。不同之处在于它放宽了 num_microbatch % pp_size == 0 的要求。使用 flex_pp 调度器,我们将得到 num_rounds = max(1, n_microbatches // pp_group_size),并且只要 n_microbatches % num_rounds 等于 0 即可工作。举例说明,支持

  1. pp_group_size = 4, n_microbatches = 10。我们将得到 num_rounds = 2,并且 n_microbatches % 2 等于 0。

  2. pp_group_size = 4, n_microbatches = 3。我们将得到 num_rounds = 1,并且 n_microbatches % 1 等于 0。

class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None, scale_grads=True)[source][source]

广度优先流水线并行。详情请参阅 https://arxiv.org/abs/2211.05953。与交错式 1F1B 类似,Looped BFS 支持每个 rank 多个阶段。不同之处在于,当多个本地阶段的微批次准备就绪时,Looped BFS 将优先处理较早的阶段,一次性运行所有可用的微批次。

class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source][source]

交错式零气泡调度器。详情请参阅 https://arxiv.org/pdf/2401.10241。将在稳态下对微批次的输入执行一次前向传播和一次反向传播,并支持每个 rank 多个阶段。利用对权重的反向传播来填充流水线气泡。

特别地,这实现了论文中的 ZB1P 调度器。

class torch.distributed.pipelining.schedules.ScheduleZBVZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source][source]

零气泡调度器(ZBV 变体)。详情请参阅 https://arxiv.org/pdf/2401.10241 的第 6 节。

这个调度器要求每个 rank 恰好有两个阶段。

这个调度器将在稳态下对微批次的输入执行一次前向传播和一次反向传播,并支持每个 rank 多个阶段。利用相对于权重的反向传播来填充流水线气泡。

只有当前向时间 == 输入反向时间 == 权重反向时间时,这种 ZB-V 调度器才具有“零气泡”特性。在实践中,对于真实模型这不太可能成立,因此对于不相等/不平衡的时间可以另外实现一个贪心调度器。

class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)[source][source]

单阶段调度器的基类。实现了 step 方法。派生类应该实现 _step_microbatches 方法。

梯度根据 scale_grads 参数按微批次数量进行缩放,默认为 True。此设置应与您的 loss_fn 配置匹配,loss_fn 可能对损失进行平均(scale_grads=True)或求和(scale_grads=False)。

step(*args, target=None, losses=None, **kwargs)[source][source]

使用完整批次输入运行流水线调度的一个迭代。将自动把输入分块成微批次,并根据调度实现遍历微批次。

args: 模型的 位置参数(与非流水线情况相同)。 kwargs: 模型的 关键词参数(与非流水线情况相同)。 target: 损失函数的目标。 losses: 一个列表,用于存储每个微批次的损失。

class torch.distributed.pipelining.schedules.PipelineScheduleMulti(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, use_full_backward=None, scale_grads=True)[source][source]

多阶段调度器的基类。实现了 step 方法。

梯度根据 scale_grads 参数按微批次数量进行缩放,默认为 True。此设置应与您的 loss_fn 配置匹配,loss_fn 可能对损失进行平均(scale_grads=True)或求和(scale_grads=False)。

step(*args, target=None, losses=None, **kwargs)[source][source]

使用完整批次输入运行流水线调度的一个迭代。将自动把输入分块成微批次,并根据调度实现遍历微批次。

args: 模型的 位置参数(与非流水线情况相同)。 kwargs: 模型的 关键词参数(与非流水线情况相同)。 target: 损失函数的目标。 losses: 一个列表,用于存储每个微批次的损失。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源