快捷方式

流水线并行

注意

torch.distributed.pipelining 目前处于 alpha 状态,仍在开发中。API 可能会发生变化。它已从 PiPPy 项目迁移而来。

为什么要进行流水线并行?

流水线并行是深度学习的基本并行方式之一。它允许对模型的执行进行分区,以便多个微批次可以同时执行模型代码的不同部分。流水线并行对于以下情况可能是一种有效的技术

  • 大规模训练

  • 带宽受限的集群

  • 大型模型推理。

上述情况都有一个共同点,即每个设备的计算量无法掩盖传统并行方式的通信开销,例如 FSDP 的权重全收集。

什么是 torch.distributed.pipelining

虽然流水线并行在扩展方面很有前景,但它通常难以实现,因为它需要在模型权重之外对模型的执行进行分区。执行分区通常需要对模型代码进行侵入式更改。另一个复杂性来自于在分布式环境中调度微批次,同时还要考虑数据流依赖性

pipelining 包提供了一个工具包,可以自动完成上述工作,从而允许在通用模型上轻松实现流水线并行。

它由两部分组成:拆分前端分布式运行时。拆分前端按原样获取模型代码,将其拆分为“模型分区”,并捕获数据流关系。分布式运行时在不同设备上并行执行流水线阶段,处理微批次拆分、调度、通信和梯度传播等。

总的来说,pipelining 包提供了以下功能

  • 基于简单规范的模型代码拆分。

  • 对流水线调度提供丰富的支持,包括 GPipe、1F1B、交错 1F1B 和循环 BFS,并提供编写自定义调度的基础设施。

  • 对跨主机流水线并行提供一流的支持,因为这是 PP 通常使用的地方(通过较慢的互连)。

  • 与其他 PyTorch 并行技术(如数据并行(DDP、FSDP)或张量并行)的可组合性。TorchTitan 项目展示了 Llama 模型上的“3D 并行”应用。

步骤 1:构建 PipelineStage 以供执行

在使用 PipelineSchedule 之前,我们需要创建 PipelineStage 对象,这些对象包装了在该阶段运行的模型部分。PipelineStage 负责分配通信缓冲区并创建发送/接收操作以与其对等方通信。它管理中间缓冲区,例如尚未使用的正向输出,并提供用于运行阶段模型反向传播的实用程序。

PipelineStage 需要知道阶段模型的输入和输出形状,以便它可以正确分配通信缓冲区。形状必须是静态的,例如,在运行时,形状不能在各个步骤之间发生变化。如果运行时形状与预期形状不匹配,则会引发 PipeliningShapeError 类。当与其他并行机制组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage 知道运行时阶段模块输出的正确形状(和数据类型)。

用户可以通过传入一个表示应该在阶段上运行的模型部分的 nn.Module 来直接构造一个 PipelineStage 实例。这可能需要更改原始模型代码。请参阅 选项 1:手动拆分模型 中的示例。

或者,拆分前端可以使用图分区将模型自动拆分为一系列 nn.Module。此技术要求模型可以使用 torch.Export 进行跟踪。生成的 nn.Module 与其他并行技术的组合性尚处于实验阶段,可能需要一些解决方法。如果用户无法轻松更改模型代码,则使用此前端可能更具吸引力。有关更多信息,请参阅 选项 2:自动拆分模型

步骤 2:使用 PipelineSchedule 进行执行

现在,我们可以将 PipelineStage 附加到流水线调度,并使用输入数据运行调度。下面是一个 GPipe 示例

from torch.distributed.pipelining import ScheduleGPipe

# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)

# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically
if rank == 0:
    schedule.step(x)
else:
    output = schedule.step()

请注意,以上代码需要为每个工作进程启动,因此我们使用启动器服务来启动多个进程

torchrun --nproc_per_node=2 example.py

拆分模型的选项

选项 1:手动拆分模型

要直接构造 PipelineStage,用户需要提供一个 nn.Module 实例,该实例拥有相关的 nn.Parametersnn.Buffers,并定义一个执行该阶段相关操作的 forward() 方法。例如,Torchtitan 中定义的 Transformer 类的精简版本展示了一种构建易于分区的模型的模式。

class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(...)

        # Using a ModuleDict lets us delete layers witout affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(...)

        self.output = nn.Linear(...)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h).float() if self.output else h
        return output

通过以下方式,可以轻松地为每个阶段配置以这种方式定义的模型:首先初始化整个模型(使用元设备以避免 OOM 错误),删除该阶段不需要的层,然后创建一个包装该模型的 PipelineStage。例如

with torch.device("meta"):
    assert num_stages == 2, "This is a simple 2-stage example"

    # we construct the entire model, then delete the parts we do not need for this stage
    # in practice, this can be done using a helper function that automatically divides up layers across stages.
    model = Transformer()

    if stage_index == 0:
        # prepare the first stage model
        del model.layers["1"]
        model.norm = None
        model.output = None

    elif stage_index == 1:
        # prepare the second stage model
        model.tok_embeddings = None
        del model.layers["0"]

    from torch.distributed.pipelining import PipelineStage
    stage = PipelineStage(
        model,
        stage_index,
        num_stages,
        device,
        input_args=example_input_microbatch,
    )

PipelineStage 需要一个示例参数 input_args 来表示阶段的运行时输入,它是一个微批次的输入数据。此参数将传递给阶段模块的 forward 方法,以确定通信所需的输入和输出形状。

当与其他数据或模型并行技术组合使用时,如果模型块的输出形状/数据类型会受到影响,则可能还需要 output_args

选项 2:自动拆分模型

如果您有一个完整的模型,并且不想花费时间将其修改为一系列“模型分区”,则 pipeline API 可以为您提供帮助。以下是一个简短的示例

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(10, 3)
        self.layers = torch.nn.ModuleList(
            Layer() for _ in range(2)
        )
        self.lm = LMHead()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.emb(x)
        for layer in self.layers:
            x = layer(x)
        x = self.lm(x)
        return x

如果我们打印模型,我们可以看到多个层次结构,这使得手动拆分变得很困难

Model(
  (emb): Embedding(10, 3)
  (layers): ModuleList(
    (0-1): 2 x Layer(
      (lin): Linear(in_features=3, out_features=3, bias=True)
    )
  )
  (lm): LMHead(
    (proj): Linear(in_features=3, out_features=3, bias=True)
  )
)

让我们看看 pipeline API 是如何工作的

from torch.distributed.pipelining import pipeline, SplitPoint

x = torch.LongTensor([1, 2, 4, 5])
pipe = pipeline(
    module=mod,
    num_chunks=1,
    example_args=(x,),
    split_spec={
        "layers.1": SplitPoint.BEGINNING,
    }
)

pipeline API 根据 split_spec 拆分模型,其中 SplitPoint.BEGINNING 表示在 forward 函数中执行某个子模块*之前*添加一个拆分点,类似地,SplitPoint.END 表示在*之后*添加拆分点。

如果我们 print(pipe),我们可以看到

GraphModule(
  (submod_0): GraphModule(
    (emb): InterpreterModule()
    (layers): Module(
      (0): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
  )
  (submod_1): GraphModule(
    (layers): Module(
      (1): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
    (lm): InterpreterModule(
      (proj): InterpreterModule()
    )
  )
)

def forward(self, x):
    submod_0 = self.submod_0(x);  x = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    return (submod_1,)

“模型分区”由子模块(submod_0submod_1)表示,每个子模块都使用原始模型操作和层次结构重建。此外,还重建了一个“根级”forward 函数,以捕获这些分区之间的数据流。此类数据流稍后将由流水线运行时以分布式方式重放。

Pipe 对象提供了一种检索“模型分区”的方法

stage_mod : nn.Module = pipe.get_stage_module(stage_idx)

您还可以使用 Pipe 在设备上创建分布式阶段运行时

stage = pipe.build_stage(stage_idx, device, group)

注意

pipeline 前端使用跟踪器(torch.export)将模型捕获到单个图中。如果您的模型无法完全图形化,则可以使用下面的手动前端。

Hugging Face 示例

在最初创建此包的 PiPPy 仓库中,我们保留了基于未修改的 Hugging Face 模型的示例。请参阅 examples/huggingface 目录。

示例包括

技术深度解析

pipeline API 如何拆分模型?

首先,pipeline API 通过跟踪模型将模型转换为有向无环图 (DAG)。它使用 torch.export 跟踪模型,这是一个 PyTorch 2 全图捕获工具。

然后,它将一个阶段所需的*操作和参数*组合到一个重建的子模块中:submod_0submod_1……

Module.children() 等传统的子模块访问方法不同,pipeline API 不仅会切割模型的模块结构,还会切割模型的*forward*函数。

这是必要的,因为 Module.children() 等模型结构仅捕获 Module.__init__() 期间的信息,而不捕获有关 Module.forward() 的任何信息。换句话说,Module.children() 缺少有关以下关键流水线化方面的信息

  • forward 中子模块的执行顺序

  • 子模块之间的激活流

  • 子模块之间是否存在任何函数运算符(例如,reluadd 操作不会被 Module.children() 捕获)。

相反,pipeline API 确保真正保留了 forward 行为。它还捕获分区之间的激活流,帮助分布式运行时在无需人工干预的情况下进行正确的发送/接收调用。

pipeline API 的另一个灵活性是拆分点可以位于模型层次结构中的任意级别。在拆分分区中,与该分区相关的原始模型层次结构将被重建,而无需您付出任何代价。因此,指向子模块或参数的完全限定名称 (FQN) 仍然有效,并且依赖 FQN 的服务(例如 FSDP、TP 或检查点)仍然可以使用分区模块运行,几乎无需更改代码。

实现您自己的调度

您可以通过扩展以下两个类之一来实现您自己的流水线调度

  • PipelineScheduleSingle

  • PipelineScheduleMulti

PipelineScheduleSingle 用于为每个秩*仅分配一个*阶段的调度。 PipelineScheduleMulti 用于为每个秩分配多个阶段的调度。

例如,ScheduleGPipeSchedule1F1BPipelineScheduleSingle 的子类。而 ScheduleInterleaved1F1BScheduleLoopedBFSPipelineScheduleMulti 的子类。

API 参考

模型拆分 API

以下 API 集将您的模型转换为流水线表示形式。

class torch.distributed.pipelining.SplitPoint(value)[源代码]

枚举。

torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)[源代码]

根据规范拆分模块。

有关详细信息,请参阅 Pipe

参数
返回类型

Pipe 类的流水线表示形式。

class torch.distributed.pipelining.Pipe(split_gm, num_stages, has_loss_and_backward, loss_spec)[源代码]
torch.distributed.pipelining.pipe_split()[源代码]

pipe_split 是一个特殊运算符,用于标记模块中阶段之间的边界。它用于将模块拆分为多个阶段。如果您的带注释模块是立即运行的,则它是一个空操作。

示例

>>> def forward(self, x):
>>>     x = torch.mm(x, self.mm_param)
>>>     x = torch.relu(x)
>>>     pipe_split()
>>>     x = self.lin(x)
>>>     return x

上面的例子将被分成两个阶段。

微批次实用程序

class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)[源代码]

用于指定输入分块的类

torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)[源代码]

给定一系列参数和关键字参数,根据它们各自的分块规范将它们分成若干块。

参数
返回值

分片参数列表 kwargs_split:分片关键字参数列表

返回类型

args_split

torch.distributed.pipelining.microbatch.merge_chunks(chunks, chunk_spec)[源代码]

给定一个块列表,根据块规范将它们合并成一个值。

参数
  • chunks (List[Any]) – 块列表

  • chunk_spec – 块的分块规范

返回值

合并后的值

返回类型

value

管道阶段

class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args, output_args=None, group=None)[源代码]

表示管道并行设置中管道阶段的类。此类是通过提供示例输入(以及可选的输出)手动创建的,而不是 pipeline() 输出的 PipelineStage 类。此类扩展了 _PipelineStageBase 类,并且可以类似地在 PipelineScheule 中使用。

参数
  • submodule (nn.Module) – 此阶段包装的 PyTorch 模块。

  • stage_index (int) – 此阶段的 ID。

  • num_stages (int) – 阶段总数。

  • device (torch.device) – 此阶段所在的设备。

  • input_args (Union[torch.Tensor, Tuple[torch.tensor]], 可选) – 子模块的输入参数。

  • output_args (Union[torch.Tensor, Tuple[torch.tensor]], 可选) – 子模块的输出参数。

  • group (dist.ProcessGroup, 可选) – 用于分布式训练的进程组。如果为 None,则为默认组。

torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)[源代码]

创建一个管道阶段,给定一个要由该阶段包装的 stage_module 和管道信息。

参数
  • stage_module (torch.nn.Module) – 要由该阶段包装的模块

  • stage_index (int) – 该阶段在管道中的索引

  • pipe_info (PipeInfo) – 有关管道的信息,可以通过 pipe.info() 获取

  • device (torch.device) – 该阶段要使用的设备

  • group (Optional[dist.ProcessGroup]) – 该阶段要使用的进程组

返回值

一个可以使用 PipelineSchedules 运行的管道阶段。

返回类型

_PipelineStage

管道调度

class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[源代码]

GPipe 调度。将以填充-清空的方式遍历所有微批次。

class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[源代码]

1F1B 调度。将在稳定状态下对微批次执行一次正向和一次反向传播。

class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[源代码]

交错 1F1B 调度。有关详细信息,请参阅 https://arxiv.org/pdf/2104.04473。将在稳定状态下对微批次执行一次正向和一次反向传播,并且每个等级支持多个阶段。当微批次准备好用于多个本地阶段时,交错 1F1B 优先考虑较早的微批次(也称为“深度优先”)。

class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None)[源代码]

广度优先流水线并行。有关详细信息,请参阅 https://arxiv.org/abs/2211.05953。与交错式 1F1B 类似,循环 BFS 支持每个秩有多个阶段。不同之处在于,当微批次准备好用于多个本地阶段时,循环 BFS 将优先考虑较早的阶段,一次运行所有可用的微批次。

class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[源代码]

单阶段调度基类。实现 step 方法。派生类应实现 _step_microbatches

step(*args, target=None, losses=None, **kwargs)[源代码]

使用_完整批次_输入运行流水线调度的一次迭代。将自动将输入分块为微批次,并根据调度实现遍历微批次。

args:模型的位置参数(如非流水线情况)。kwargs:模型的关键字参数(如非流水线情况)。target:损失函数的目标。losses:用于存储每个微批次损失的列表。

class torch.distributed.pipelining.schedules.PipelineScheduleMulti(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None)[源代码]

多阶段调度基类。实现 step 方法。

step(*args, target=None, losses=None, **kwargs)[源代码]

使用_完整批次_输入运行流水线调度的一次迭代。将自动将输入分块为微批次,并根据调度实现遍历微批次。

args:模型的位置参数(如非流水线情况)。kwargs:模型的关键字参数(如非流水线情况)。target:损失函数的目标。losses:用于存储每个微批次损失的列表。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源