训练脚本¶
如果您的训练脚本适用于 torch.distributed.launch,它将继续适用于 torchrun,但存在以下差异
- 无需手动传递 - RANK、- WORLD_SIZE、- MASTER_ADDR和- MASTER_PORT。
- rdzv_backend和- rdzv_endpoint可以提供。对于大多数用户,这将设置为- c10d(请参阅 rendezvous)。默认- rdzv_backend创建一个非弹性 rendezvous,其中- rdzv_endpoint保存主地址。
- 确保在脚本中具有 - load_checkpoint(path)和- save_checkpoint(path)逻辑。当任意数量的 worker 失败时,我们会使用相同的程序参数重新启动所有 worker,因此您将丢失最新检查点之前的进度(请参阅 elastic launch)。
- use_env标志已被移除。如果您通过解析- --local-rank选项来解析本地 rank,则需要从环境变量- LOCAL_RANK(例如- int(os.environ["LOCAL_RANK"]))获取本地 rank。
以下是训练脚本的说明性示例,该脚本在每个 epoch 进行检查点,因此失败时丢失的最坏情况进度是一个完整的 epoch 训练。
def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)
     # torch.distributed.run ensures that this will work
     # by exporting all the env vars needed to initialize the process group
     torch.distributed.init_process_group(backend=args.backend)
     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)
          state.epoch += 1
          save_checkpoint(state)
有关符合 torchelastic 的训练脚本的具体示例,请访问我们的 示例 页面。