Introduction to Distributed Pipeline Parallelism¶
Created On: July 9, 2024 | Last Updated: December 12, 2024 | Last Verified: November 5, 2024
Author: Howard Huang
Note
View and edit this tutorial on GitHub.
This tutorial uses a GPT-style transformer model to demonstrate how to implement distributed pipeline parallelism with the torch.distributed.pipelining APIs.
What you will learn

- How to use the torch.distributed.pipelining APIs
- How to apply pipeline parallelism to a transformer model
- How to utilize different schedules on a set of microbatches

Prerequisites

- Familiarity with basic distributed training in PyTorch
Setup¶
With torch.distributed.pipelining we will partition the execution of the model and schedule computation on microbatches. We will use a simplified version of a transformer decoder model. The model architecture is for educational purposes only and has multiple transformer decoder layers so that we can demonstrate how to split the model into different chunks. First, let us define the model:
import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = 10000

class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

        # Using a ModuleDict lets us delete layers without affecting names,
        # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)

        self.norm = nn.LayerNorm(model_args.dim)
        self.output = nn.Linear(model_args.dim, model_args.vocab_size)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, h)

        h = self.norm(h) if self.norm else h
        output = self.output(h).clone() if self.output else h
        return output
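As a quick sanity check (not part of the original tutorial script), the unsplit model can be exercised on a small batch of dummy token IDs; the shapes below follow from the defaults in ModelArgs:

args = ModelArgs()
model = Transformer(args)
tokens = torch.randint(0, args.vocab_size, (2, 16))   # (batch, sequence) of dummy token IDs
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 10000]) -- one score per vocabulary entry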
Then, we need to import the necessary libraries in our script and initialize the distributed training process. In this case, we define some global variables to be used later in the script:
import os
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe

global rank, device, pp_group, stage_index, num_stages
def init_distributed():
    global rank, device, pp_group, stage_index, num_stages
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
    dist.init_process_group()

    # This group can be a sub-group in the N-D parallel case
    pp_group = dist.new_group()
    stage_index = rank
    num_stages = world_size
The rank, world_size, and init_process_group() code should seem familiar to you as those are commonly used in all distributed programs. The globals specific to pipeline parallelism include pp_group, which is the process group that will be used for send/recv communications; stage_index, which in this example is a single rank per stage, so the index is equivalent to the rank; and num_stages, which is equivalent to world_size.

num_stages is used to set the number of stages that will be used in the pipeline parallelism schedule. For example, for num_stages=4, a microbatch will need to go through 4 forwards and 4 backwards before it is completed. stage_index is necessary for the framework to know how to communicate between stages. For example, for the first stage (stage_index=0), it will use data from the dataloader and does not need to receive data from any previous peers to perform its computation.
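The comment in init_distributed() notes that pp_group can be a sub-group in the N-D parallel case. The sketch below shows what that might look like; it is not part of this tutorial, and the mesh shape and dimension names ("dp", "pp") are illustrative assumptions:

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def init_distributed_nd():
    # Illustrative only: 2 data-parallel replicas x 2 pipeline stages = 4 ranks.
    global rank, device, pp_group, stage_index, num_stages
    rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    mesh = init_device_mesh(device_type, (2, 2), mesh_dim_names=("dp", "pp"))
    pp_group = mesh.get_group("pp")             # send/recv stays within this sub-group
    stage_index = dist.get_rank(pp_group)       # position along the pipeline dimension
    num_stages = dist.get_world_size(pp_group)  # number of pipeline stages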
Step 1: Partition the Transformer Model¶
There are two different ways to partition the model.

The first is manual mode, in which we manually create two instances of the model by deleting portions of its attributes. In this two-stage (2-rank) example, the model is cut in half.
def manual_model_split(model) -> PipelineStage:
    if stage_index == 0:
        # prepare the first stage model
        for i in range(4, 8):
            del model.layers[str(i)]
        model.norm = None
        model.output = None

    elif stage_index == 1:
        # prepare the second stage model
        for i in range(4):
            del model.layers[str(i)]
        model.tok_embeddings = None

    stage = PipelineStage(
        model,
        stage_index,
        num_stages,
        device,
    )
    return stage
As we can see, the first stage does not have the layer norm or the output layer, and it only includes the first four transformer blocks. The second stage does not have the input embedding layer, but includes the output layer and the final four transformer blocks. The function then returns the PipelineStage for the current rank.
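As a hypothetical check (not part of the tutorial), you could print which submodules survive the split on each rank; attributes set to None no longer appear in named_children():

kept = [name for name, _ in model.named_children()]
print(f"rank {rank} keeps: {kept}, layers: {list(model.layers.keys())}")
# Expected: rank 0 keeps tok_embeddings and layers 0-3; rank 1 keeps layers 4-7, norm, and output.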
The second method is the tracer-based mode, which automatically splits the model based on a split_spec argument. Using the pipeline specification, we can instruct torch.distributed.pipelining where to split the model. In the following code block, we split before the 4th transformer decoder layer, mirroring the manual split described above. Similarly, after this split is done, we can retrieve a PipelineStage by calling build_stage.
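A sketch of this tracer-based split is shown below. It relies on the pipeline() helper and SplitPoint imported earlier; treat the exact keyword arguments as an assumption about the API's usage. The "layers.4" key mirrors the manual split (layers 0-3 on stage 0, layers 4-7 on stage 1):

def tracer_model_split(model, example_input_microbatch) -> PipelineStage:
    pipe = pipeline(
        module=model,
        mb_args=(example_input_microbatch,),
        split_spec={
            # Split at the boundary before model.layers["4"].
            "layers.4": SplitPoint.BEGINNING,
        },
    )
    stage = pipe.build_stage(stage_index, device)
    return stage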
Step 2: Define the Main Execution¶
In the main function, we will create a particular pipeline schedule that the stages should follow. torch.distributed.pipelining supports multiple schedules, including the single-stage-per-rank schedules GPipe and 1F1B, as well as the multiple-stage-per-rank schedules Interleaved1F1B and LoopedBFS.
if __name__ == "__main__":
    init_distributed()
    num_microbatches = 4
    model_args = ModelArgs()
    model = Transformer(model_args)

    # Dummy data
    x = torch.ones(32, 500, dtype=torch.long)
    y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)
    example_input_microbatch = x.chunk(num_microbatches)[0]

    # Option 1: Manual model splitting
    stage = manual_model_split(model)

    # Option 2: Tracer model splitting
    # stage = tracer_model_split(model, example_input_microbatch)

    model.to(device)
    x = x.to(device)
    y = y.to(device)

    def tokenwise_loss_fn(outputs, targets):
        loss_fn = nn.CrossEntropyLoss()
        outputs = outputs.reshape(-1, model_args.vocab_size)
        targets = targets.reshape(-1)
        return loss_fn(outputs, targets)

    schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)

    if rank == 0:
        schedule.step(x)
    elif rank == 1:
        losses = []
        output = schedule.step(target=y, losses=losses)
        print(f"losses: {losses}")
    dist.destroy_process_group()
In the example above, we use the manual method to split the model, but the code can be uncommented to also try the tracer-based model splitting function. In our schedule, we need to pass in the number of microbatches and the loss function used to evaluate the targets.
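The schedule class itself is interchangeable with the other schedules mentioned above. For example, assuming Schedule1F1B is available in your PyTorch build, it takes the same arguments (a sketch, not part of the tutorial script):

from torch.distributed.pipelining import Schedule1F1B

# Same stage, microbatch count, and loss function; only the scheduling strategy changes.
schedule = Schedule1F1B(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)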
The .step() function processes the entire minibatch and automatically splits it into microbatches based on the n_microbatches passed previously. The microbatches are then operated on according to the schedule class. In the example above, we are using GPipe, which follows a simple all-forwards and then all-backwards schedule. The output returned from rank 1 will be the same as if the model were on a single GPU and run with the entire batch. Similarly, we can pass in a losses container to store the corresponding losses for each microbatch.
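To go beyond a single call, one could wrap schedule.step() in a basic training loop. This is a hypothetical extension, not part of the tutorial script; the optimizer choice and iteration count are illustrative:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(10):  # illustrative number of training iterations
    optimizer.zero_grad()
    if rank == 0:
        schedule.step(x)                         # first stage: feed the input minibatch
    elif rank == 1:
        losses = []
        schedule.step(target=y, losses=losses)   # last stage: supply targets, collect per-microbatch losses
    optimizer.step()                             # each rank updates only its own stage's parameters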
Step 3: Launch the Distributed Processes¶
Finally, we are ready to run the script. We will use torchrun to create a single-host, 2-process job. Our script is already written so that rank 0 performs the logic required for pipeline stage 0, and rank 1 performs the logic for pipeline stage 1.
torchrun --nnodes 1 --nproc_per_node 2 pipelining_tutorial.py
Conclusion¶
In this tutorial, we have learned how to implement distributed pipeline parallelism using PyTorch's torch.distributed.pipelining APIs. We explored how to set up the environment, define a transformer model, and partition it for distributed training. We discussed the two methods of model partitioning, manual and tracer-based, and demonstrated how to schedule computation on microbatches across the different stages. Finally, we covered the execution of the pipeline schedule and the launching of distributed processes using torchrun.
Additional Resources¶
We have successfully integrated torch.distributed.pipelining into the torchtitan repository. TorchTitan is a clean, minimal codebase for large-scale LLM training using native PyTorch. For a production-ready usage of pipeline parallelism, as well as its composition with other distributed techniques, see the TorchTitan end-to-end 3D parallelism example.