训练脚本¶
如果您的训练脚本适用于 torch.distributed.launch
,则它将继续适用于 torchrun
,但存在以下差异
无需手动传递
RANK
、WORLD_SIZE
、MASTER_ADDR
和MASTER_PORT
。可以提供
rdzv_backend
和rdzv_endpoint
。对于大多数用户,这将设置为c10d
(请参阅汇合)。默认的rdzv_backend
会创建一个非弹性汇合,其中rdzv_endpoint
保存主地址。确保您的脚本中包含
load_checkpoint(path)
和save_checkpoint(path)
逻辑。当任意数量的工作进程失败时,我们会使用相同的程序参数重启所有工作进程,因此您将丢失到最近检查点的进度(请参阅弹性启动)。use_env
标志已移除。如果您之前通过解析--local-rank
选项来解析本地排名,则需要从环境变量LOCAL_RANK
中获取本地排名(例如,int(os.environ["LOCAL_RANK"])
)。
以下是一个训练脚本的说明性示例,该脚本在每个时期进行检查点,因此在失败时最坏情况下丢失的进度是一个完整时期的训练量。
def main():
args = parse_args(sys.argv[1:])
state = load_checkpoint(args.checkpoint_path)
initialize(state)
# torch.distributed.run ensures that this will work
# by exporting all the env vars needed to initialize the process group
torch.distributed.init_process_group(backend=args.backend)
for i in range(state.epoch, state.total_num_epochs)
for batch in iter(state.dataset)
train(batch, state.model)
state.epoch += 1
save_checkpoint(state)
有关符合 torchelastic 规范的训练脚本的具体示例,请访问我们的示例页面。