快捷方式

训练脚本

如果您的训练脚本适用于 torch.distributed.launch,则它将继续适用于 torchrun,但存在以下差异

  1. 无需手动传递 RANKWORLD_SIZEMASTER_ADDRMASTER_PORT

  2. 可以提供 rdzv_backendrdzv_endpoint。对于大多数用户,这将设置为 c10d(请参阅汇合)。默认的 rdzv_backend 会创建一个非弹性汇合,其中 rdzv_endpoint 保存主地址。

  3. 确保您的脚本中包含 load_checkpoint(path)save_checkpoint(path) 逻辑。当任意数量的工作进程失败时,我们会使用相同的程序参数重启所有工作进程,因此您将丢失到最近检查点的进度(请参阅弹性启动)。

  4. use_env 标志已移除。如果您之前通过解析 --local-rank 选项来解析本地排名,则需要从环境变量 LOCAL_RANK 中获取本地排名(例如,int(os.environ["LOCAL_RANK"]))。

以下是一个训练脚本的说明性示例,该脚本在每个时期进行检查点,因此在失败时最坏情况下丢失的进度是一个完整时期的训练量。

def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)

     # torch.distributed.run ensures that this will work
     # by exporting all the env vars needed to initialize the process group
     torch.distributed.init_process_group(backend=args.backend)

     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state)

有关符合 torchelastic 规范的训练脚本的具体示例,请访问我们的示例页面。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源