分布式流水线并行简介¶
创建于:2024 年 7 月 9 日 | 最后更新:2024 年 12 月 12 日 | 最后验证:2024 年 11 月 05 日
作者: Howard Huang
注意
在 github 上查看和编辑本教程。
本教程使用 gpt 风格的 Transformer 模型来演示如何使用 torch.distributed.pipelining API 实现分布式流水线并行。
如何使用
torch.distributed.pipelining
API如何将流水线并行应用于 Transformer 模型
如何在微批次集合上利用不同的调度
熟悉 PyTorch 中的 基本分布式训练
设置¶
使用 torch.distributed.pipelining
,我们将对模型的执行进行分区,并在微批次上调度计算。我们将使用简化版本的 Transformer 解码器模型。模型架构用于教育目的,并具有多个 Transformer 解码器层,因为我们想演示如何将模型拆分为不同的块。首先,让我们定义模型
import torch
import torch.nn as nn
from dataclasses import dataclass
@dataclass
class ModelArgs:
dim: int = 512
n_layers: int = 8
n_heads: int = 8
vocab_size: int = 10000
class Transformer(nn.Module):
def __init__(self, model_args: ModelArgs):
super().__init__()
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
# Using a ModuleDict lets us delete layers witout affecting names,
# ensuring checkpoints will correctly save and load.
self.layers = torch.nn.ModuleDict()
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)
self.norm = nn.LayerNorm(model_args.dim)
self.output = nn.Linear(model_args.dim, model_args.vocab_size)
def forward(self, tokens: torch.Tensor):
# Handling layers being 'None' at runtime enables easy pipeline splitting
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
for layer in self.layers.values():
h = layer(h, h)
h = self.norm(h) if self.norm else h
output = self.output(h).clone() if self.output else h
return output
然后,我们需要在脚本中导入必要的库并初始化分布式训练过程。在本例中,我们定义了一些全局变量以供稍后在脚本中使用
import os
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe
global rank, device, pp_group, stage_index, num_stages
def init_distributed():
global rank, device, pp_group, stage_index, num_stages
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
dist.init_process_group()
# This group can be a sub-group in the N-D parallel case
pp_group = dist.new_group()
stage_index = rank
num_stages = world_size
rank
、world_size
和 init_process_group()
代码对您来说应该很熟悉,因为这些代码在所有分布式程序中都很常用。特定于流水线并行的全局变量包括 pp_group
,它是将用于发送/接收通信的进程组;stage_index
,在本例中,每个阶段只有一个 rank,因此索引等同于 rank;以及 num_stages
,它等同于 world_size。
num_stages
用于设置流水线并行调度中将使用的阶段数。例如,对于 num_stages=4
,一个微批次需要经过 4 次前向和 4 次反向传播才能完成。stage_index
对于框架了解如何在阶段之间进行通信是必要的。例如,对于第一阶段 (stage_index=0
),它将使用来自数据加载器的数据,并且不需要从任何先前的对等方接收数据来执行其计算。
步骤 1:划分 Transformer 模型¶
有两种不同的模型划分方法
第一种是手动模式,在这种模式下,我们可以通过删除模型属性的部分来手动创建模型的两个实例。在本例中,对于两个阶段(2 个 rank),模型被分成两半。
def manual_model_split(model) -> PipelineStage:
if stage_index == 0:
# prepare the first stage model
for i in range(4, 8):
del model.layers[str(i)]
model.norm = None
model.output = None
elif stage_index == 1:
# prepare the second stage model
for i in range(4):
del model.layers[str(i)]
model.tok_embeddings = None
stage = PipelineStage(
model,
stage_index,
num_stages,
device,
)
return stage
正如我们所见,第一阶段没有层归一化或输出层,它只包括前四个 Transformer 块。第二阶段没有输入嵌入层,但包括输出层和最后四个 Transformer 块。然后,该函数返回当前 rank 的 PipelineStage
。
第二种方法是基于 Tracer 的模式,它根据 split_spec
参数自动拆分模型。使用流水线规范,我们可以指示 torch.distributed.pipelining
在何处拆分模型。在以下代码块中,我们将在第 4 个 Transformer 解码器层之前进行拆分,这与上面描述的手动拆分相呼应。同样,我们可以通过在完成拆分后调用 build_stage
来检索 PipelineStage
。
步骤 2:定义主执行¶
在主函数中,我们将创建一个阶段应遵循的特定流水线调度。torch.distributed.pipelining
支持多种调度,包括单阶段每 rank 调度 GPipe
和 1F1B
,以及多阶段每 rank 调度,例如 Interleaved1F1B
和 LoopedBFS
。
if __name__ == "__main__":
init_distributed()
num_microbatches = 4
model_args = ModelArgs()
model = Transformer(model_args)
# Dummy data
x = torch.ones(32, 500, dtype=torch.long)
y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)
example_input_microbatch = x.chunk(num_microbatches)[0]
# Option 1: Manual model splitting
stage = manual_model_split(model)
# Option 2: Tracer model splitting
# stage = tracer_model_split(model, example_input_microbatch)
model.to(device)
x = x.to(device)
y = y.to(device)
def tokenwise_loss_fn(outputs, targets):
loss_fn = nn.CrossEntropyLoss()
outputs = outputs.reshape(-1, model_args.vocab_size)
targets = targets.reshape(-1)
return loss_fn(outputs, targets)
schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)
if rank == 0:
schedule.step(x)
elif rank == 1:
losses = []
output = schedule.step(target=y, losses=losses)
print(f"losses: {losses}")
dist.destroy_process_group()
在上面的示例中,我们使用手动方法来拆分模型,但可以取消注释代码以尝试基于 Tracer 的模型拆分函数。在我们的调度中,我们需要传入微批次的数量和用于评估目标的损失函数。
.step()
函数处理整个小批量,并根据先前传递的 n_microbatches
自动将其拆分为微批次。然后,根据调度类对微批次进行操作。在上面的示例中,我们使用 GPipe,它遵循简单的前向全部和后向全部调度。从 rank 1 返回的输出将与模型在单个 GPU 上运行并使用整个批次运行的输出相同。同样,我们可以传入 losses
容器来存储每个微批次的相应损失。
步骤 3:启动分布式进程¶
最后,我们准备好运行脚本了。我们将使用 torchrun
创建单主机、2 进程作业。我们的脚本已经以 rank 0 执行流水线阶段 0 所需的逻辑,rank 1 执行流水线阶段 1 的逻辑的方式编写。
torchrun --nnodes 1 --nproc_per_node 2 pipelining_tutorial.py
结论¶
在本教程中,我们学习了如何使用 PyTorch 的 torch.distributed.pipelining
API 实现分布式流水线并行。我们探讨了设置环境、定义 Transformer 模型以及对其进行分区以进行分布式训练。我们讨论了模型分区的两种方法:手动和基于 Tracer 的方法,并演示了如何在不同阶段的微批次上调度计算。最后,我们介绍了流水线调度的执行以及使用 torchrun
启动分布式进程。
其他资源¶
我们已成功将 torch.distributed.pipelining
集成到 torchtitan 存储库中。TorchTitan 是一个干净、最小的代码库,用于使用原生 PyTorch 进行大规模 LLM 训练。有关流水线并行的生产就绪用法以及与其他分布式技术的组合,请参阅 TorchTitan 端到端 3D 并行示例。