• 教程 >
  • 使用 torchrun 进行容错分布式训练
快捷方式

入门 || 什么是 DDP || 单节点多 GPU 训练 || 容错 || 多节点训练 || minGPT 训练

使用 torchrun 进行容错分布式训练

创建日期:2022年9月27日 | 最后更新:2024年11月12日 | 最后验证:2024年11月5日

作者:Suraj Subramanian

你将学到什么
  • 使用 torchrun 启动多 GPU 训练任务

  • 保存和加载训练任务快照

  • 构建训练脚本以实现优雅重启

GitHub 上查看本教程使用的代码

前置条件
  • DDP 的高层级概览

  • 熟悉 DDP 代码

  • 一台配备多个 GPU 的机器(本教程使用 AWS p3.8xlarge 实例)

  • 安装支持 CUDA 的 PyTorch

请观看下方视频或在 YouTube 上观看。

在分布式训练中,单个进程故障可能会中断整个训练任务。由于这里发生故障的可能性更高,因此使你的训练脚本具有鲁棒性尤其重要。你可能还希望训练任务具有弹性,例如,计算资源可以在任务执行期间动态加入和离开。

PyTorch 提供了一个名为 torchrun 的实用程序,它提供了容错和弹性训练功能。当发生故障时,torchrun 会记录错误并尝试从上次保存的训练任务“快照”自动重启所有进程。

快照不仅保存模型状态,还可以包含已运行的 epoch 数量、优化器状态或训练任务连续性所需的任何其他有状态属性的详细信息。

为何使用 torchrun

torchrun 处理分布式训练的细节,因此你无需关心这些。例如,

  • 你无需设置环境变量或显式传递 rankworld_sizetorchrun 会自动分配这些以及其他几个环境变量

  • 你无需在脚本中调用 mp.spawn;你只需要一个通用的 main() 入口点,然后使用 torchrun 启动脚本。这样,同一个脚本可以在非分布式、单节点和多节点环境中运行。

  • 从上次保存的训练快照优雅地重启训练。

优雅重启

为了实现优雅重启,你应该这样组织你的训练脚本:

def main():
  load_snapshot(snapshot_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_snapshot(snapshot_path)

如果发生故障,torchrun 将终止所有进程并重新启动它们。每个进程入口点首先加载并初始化上次保存的快照,然后从那里继续训练。因此,在任何故障发生时,你只会丢失上次保存快照之后的训练进度。

在弹性训练中,无论何时发生成员变化(添加或移除节点),torchrun 都会终止并在可用设备上生成进程。拥有这种结构可以确保你的训练任务可以在无需手动干预的情况下继续进行。

multigpu.pymultigpu_torchrun.py 的差异对比

进程组初始化

- def ddp_setup(rank, world_size):
+ def ddp_setup():
-     """
-     Args:
-         rank: Unique identifier of each process
-         world_size: Total number of processes
-     """
-     os.environ["MASTER_ADDR"] = "localhost"
-     os.environ["MASTER_PORT"] = "12355"
-     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+     init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

使用 torchrun 提供的环境变量

- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])

保存和加载快照

定期将所有相关信息存储在快照中,可以使我们的训练任务在中断后无缝恢复。

+ def _save_snapshot(self, epoch):
+     snapshot = {}
+     snapshot["MODEL_STATE"] = self.model.module.state_dict()
+     snapshot["EPOCHS_RUN"] = epoch
+     torch.save(snapshot, "snapshot.pt")
+     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

+ def _load_snapshot(self, snapshot_path):
+     snapshot = torch.load(snapshot_path)
+     self.model.load_state_dict(snapshot["MODEL_STATE"])
+     self.epochs_run = snapshot["EPOCHS_RUN"]
+     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

在 Trainer 构造函数中加载快照

当重启中断的训练任务时,你的脚本将首先尝试加载快照以从中恢复训练。

class Trainer:
   def __init__(self, snapshot_path, ...):
   ...
+  if os.path.exists(snapshot_path):
+     self._load_snapshot(snapshot_path)
   ...

恢复训练

训练可以从上次运行的 epoch 恢复,而不是从头开始。

def train(self, max_epochs: int):
-  for epoch in range(max_epochs):
+  for epoch in range(self.epochs_run, max_epochs):
      self._run_epoch(epoch)

运行脚本

就像运行非多进程脚本一样简单地调用你的入口点函数;torchrun 会自动生成进程。

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  world_size = torch.cuda.device_count()
-  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+  main(save_every, total_epochs)
- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10

为本教程评分

© 版权所有 2024, PyTorch.

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源