快捷方式

torchrun (弹性启动)

模块 torch.distributed.run

torch.distributed.run 是一个模块,用于在每个训练节点上启动多个分布式训练进程。

torchrun 是一个 Python 控制台脚本,指向在 setup.pyentry_points 配置中声明的主模块 torch.distributed.run。它等同于调用 python -m torch.distributed.run

torchrun 可用于单节点分布式训练,其中每个节点将启动一个或多个进程。它可用于 CPU 训练或 GPU 训练。如果用于 GPU 训练,每个分布式进程将在单个 GPU 上运行。这可以显著提高单节点训练性能。torchrun 也可用于多节点分布式训练,通过在每个节点上启动多个进程来提高多节点分布式训练性能。对于具有直接 GPU 支持的多个 Infiniband 接口的系统而言,这尤其有利,因为所有接口都可以用于聚合通信带宽。

在单节点分布式训练或多节点分布式训练这两种情况下,torchrun 将按给定数量在每个节点上启动进程 (--nproc-per-node)。如果用于 GPU 训练,此数量需要小于或等于当前系统上的 GPU 数量 (nproc_per_node),并且每个进程将在单个 GPU 上运行,从 GPU 0 到 GPU (nproc_per_node - 1)

2.0.0 版本中有所更改: torchrun 将把参数 --local-rank=<rank> 传递给您的脚本。从 PyTorch 2.0.0 开始,带短划线的 --local-rank 优先于先前使用的带下划线的 --local_rank

为了向后兼容,用户可能需要在其参数解析代码中处理这两种情况。这意味着在参数解析器中同时包含 "--local-rank""--local_rank"。如果仅提供了 "--local_rank"torchrun 将会触发错误:“error: unrecognized arguments: –local-rank=<rank>”。对于仅支持 PyTorch 2.0.0+ 的训练代码,包含 "--local-rank" 就足够了。

>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument("--local-rank", "--local_rank", type=int)
>>> args = parser.parse_args()

用法

单节点多工作进程

torchrun
    --standalone
    --nnodes=1
    --nproc-per-node=$NUM_TRAINERS
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

注意

--nproc-per-node 可以是 "gpu"(为每个 GPU 启动一个进程)、"cpu"(为每个 CPU 启动一个进程)、"auto"(如果 CUDA 可用则等同于 "gpu",否则等同于 "cpu"),或者是一个指定进程数量的整数。更多详情请参见 torch.distributed.run.determine_local_world_size

堆叠式单节点多工作进程

要在同一主机上运行多个单节点多工作进程实例(独立作业),我们需要确保每个实例(作业)设置在不同的端口上,以避免端口冲突(或更糟的是,两个作业合并为一个作业)。为此,您必须使用 --rdzv-backend=c10d 运行,并通过设置 --rdzv-endpoint=localhost:$PORT_k 指定不同的端口。对于 --nodes=1,通常让 torchrun 自动选择一个空闲的随机端口会更方便,而不是为每次运行手动分配不同的端口。

torchrun
    --rdzv-backend=c10d
    --rdzv-endpoint=localhost:0
    --nnodes=1
    --nproc-per-node=$NUM_TRAINERS
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

容错 (固定数量的工作进程,无弹性,容忍 3 个故障)

torchrun
    --nnodes=$NUM_NODES
    --nproc-per-node=$NUM_TRAINERS
    --max-restarts=3
    --rdzv-id=$JOB_ID
    --rdzv-backend=c10d
    --rdzv-endpoint=$HOST_NODE_ADDR
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

HOST_NODE_ADDR,格式为 <host>[:<port>](例如 node1.example.com:29400),指定了实例化和托管 C10d 汇合后端的节点和端口。它可以是您训练集群中的任何节点,但理想情况下应选择具有高带宽的节点。

注意

如果未指定端口号,HOST_NODE_ADDR 默认为 29400。

弹性 (min=1, max=4, 最多容忍 3 次成员变更或故障)

torchrun
    --nnodes=1:4
    --nproc-per-node=$NUM_TRAINERS
    --max-restarts=3
    --rdzv-id=$JOB_ID
    --rdzv-backend=c10d
    --rdzv-endpoint=$HOST_NODE_ADDR
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

HOST_NODE_ADDR,格式为 <host>[:<port>](例如 node1.example.com:29400),指定了实例化和托管 C10d 汇合后端的节点和端口。它可以是您训练集群中的任何节点,但理想情况下应选择具有高带宽的节点。

注意

如果未指定端口号,HOST_NODE_ADDR 默认为 29400。

关于汇合后端的注意事项

对于多节点训练,您需要指定

  1. --rdzv-id: 一个唯一的作业 ID(由参与作业的所有节点共享)

  2. --rdzv-backend: torch.distributed.elastic.rendezvous.RendezvousHandler 的实现

  3. --rdzv-endpoint: 汇合后端运行的端点;通常形式为 host:port

目前内置支持 c10d(推荐)、etcd-v2etcd(遗留)汇合后端。要使用 etcd-v2etcd,请设置一个启用 v2 API(例如 --enable-v2)的 etcd 服务器。

警告

etcd-v2etcd 汇合后端使用 etcd API v2。您必须在 etcd 服务器上启用 v2 API。我们的测试使用 etcd v3.4.3。

警告

对于基于 etcd 的汇合,我们建议使用 etcd-v2 而非 etcd,前者功能上等同,但使用了改进的实现。etcd 已进入维护模式,并将在未来版本中移除。

定义

  1. 节点 (Node) - 一个物理实例或容器;对应于作业管理器处理的单元。

  2. 工作进程 (Worker) - 分布式训练中的一个工作进程。

  3. 工作进程组 (WorkerGroup) - 执行相同功能(例如,训练器)的工作进程集合。

  4. 本地工作进程组 (LocalWorkerGroup) - 运行在同一节点上的工作进程组中的一个子集。

  5. 全局秩 (RANK) - 工作进程在工作进程组中的秩。

  6. 全局工作进程数 (WORLD_SIZE) - 工作进程组中的工作进程总数。

  7. 本地秩 (LOCAL_RANK) - 工作进程在本地工作进程组中的秩。

  8. 本地工作进程数 (LOCAL_WORLD_SIZE) - 本地工作进程组的大小。

  9. rdzv_id - 用户定义的 ID,唯一标识一个作业的工作进程组。此 ID 由每个节点用于加入特定工作进程组作为成员。

  1. rdzv_backend - 汇合的后端(例如 c10d)。这通常是一个强一致性的键值存储。

  2. rdzv_endpoint - 汇合后端端点;通常形式为 <host>:<port>

一个 节点 (Node) 运行 本地工作进程数 (LOCAL_WORLD_SIZE) 个工作进程,这些工作进程构成一个 本地工作进程组 (LocalWorkerGroup)。作业中所有节点的 本地工作进程组 (LocalWorkerGroup) 的并集构成了 工作进程组 (WorkerGroup)

环境变量

以下环境变量在您的脚本中可用

  1. LOCAL_RANK - 本地秩。

  2. RANK - 全局秩。

  3. GROUP_RANK - 工作进程组的秩。一个介于 0 和 max_nnodes 之间的数字。当每个节点运行一个工作进程组时,这表示节点的秩。

  4. ROLE_RANK - 在所有具有相同角色的工作进程中的秩。工作进程的角色在 WorkerSpec 中指定。

  5. LOCAL_WORLD_SIZE - 本地工作进程数(例如,本地运行的工作进程数量);等于在 torchrun 上指定的 --nproc-per-node

  6. WORLD_SIZE - 全局工作进程数(作业中的工作进程总数)。

  7. ROLE_WORLD_SIZE - 以 WorkerSpec 中指定的相同角色启动的工作进程总数。

  8. MASTER_ADDR - 运行秩为 0 的工作进程的主机的 FQDN(完全限定域名);用于初始化 Torch 分布式后端。

  9. MASTER_PORT - MASTER_ADDR 上可用于托管 C10d TCP store 的端口。

  10. TORCHELASTIC_RESTART_COUNT - 工作进程组到目前为止的重启次数。

  11. TORCHELASTIC_MAX_RESTARTS - 配置的最大重启次数。

  12. TORCHELASTIC_RUN_ID - 等于汇合 run_id(例如,唯一的作业 ID)。

  13. PYTHON_EXEC - 系统可执行文件覆盖。如果提供,Python 用户脚本将使用 PYTHON_EXEC 的值作为可执行文件。默认使用 sys.executable

部署

  1. (C10d 后端不需要)启动汇合后端服务器并获取端点(作为 --rdzv-endpoint 传递给 torchrun

  2. 单节点多工作进程:在主机上启动 torchrun 以启动代理进程,该进程创建并监控本地工作进程组。

  3. 多节点多工作进程:在所有参与训练的节点上使用相同的参数启动 torchrun

使用作业/集群管理器时,多节点作业的入口点命令应该是 torchrun

故障模式

  1. 工作进程故障:对于具有 n 个工作进程的训练作业,如果 k<=n 个工作进程发生故障,所有工作进程将被停止并重启,最多重启 max_restarts 次。

  2. 代理故障:代理故障导致本地工作进程组故障。作业管理器可以选择使整个作业失败(组语义)或尝试替换节点。这两种行为都由代理支持。

  3. 节点故障:与代理故障相同。

成员变更

  1. 节点离开(缩容):代理收到离开通知,所有现有工作进程停止,形成新的 WorkerGroup,所有工作进程使用新的 RANKWORLD_SIZE 启动。

  2. 节点加入(扩容):新节点被添加到作业中,所有现有工作进程停止,形成新的 WorkerGroup,所有工作进程使用新的 RANKWORLD_SIZE 启动。

重要注意事项

  1. 该工具和多进程分布式(单节点或多节点)GPU 训练目前仅在使用 NCCL 分布式后端时能获得最佳性能。因此,NCCL 后端是 GPU 训练推荐使用的后端。

  2. 初始化 Torch 进程组所需的环境变量由该模块提供,您无需手动传递 RANK。要在您的训练脚本中初始化进程组,只需运行

>>> import torch.distributed as dist
>>> dist.init_process_group(backend="gloo|nccl")
  1. 在您的训练程序中,您可以使用常规的分布式函数,或使用 torch.nn.parallel.DistributedDataParallel() 模块。如果您的训练程序使用 GPU 进行训练并希望使用 torch.nn.parallel.DistributedDataParallel() 模块,以下是配置方法。

local_rank = int(os.environ["LOCAL_RANK"])
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], output_device=local_rank
)

请确保 device_ids 参数设置为您的代码将操作的唯一 GPU 设备 ID。这通常是进程的本地秩。换句话说,为了使用此工具,device_ids 需要设置为 [int(os.environ("LOCAL_RANK"))],并且 output_device 需要设置为 int(os.environ("LOCAL_RANK"))

  1. 在故障或成员变更发生时,所有存活的工作进程都会立即终止。请务必定期保存检查点。检查点的频率应取决于您的作业对丢失工作的容忍度。

  2. 此模块仅支持同构的 本地工作进程数 (LOCAL_WORLD_SIZE)。也就是说,假定所有节点运行相同数量的本地工作进程(按角色)。

  3. 全局秩 (RANK) 是不稳定的。在重启之间,节点上的本地工作进程可能被分配与之前不同的秩范围。切勿硬编码关于秩稳定性或 RANKLOCAL_RANK 之间相关性的任何假设。

  4. 使用弹性功能(min_size!=max_size)时,请勿硬编码关于 全局工作进程数 (WORLD_SIZE) 的假设,因为随着节点允许离开和加入,全局工作进程数可能会发生变化。

  5. 建议您的脚本采用以下结构

def main():
    load_checkpoint(checkpoint_path)
    initialize()
    train()


def train():
    for batch in iter(dataset):
        train_step(batch)

        if should_checkpoint:
            save_checkpoint(checkpoint_path)
  1. (推荐)当工作进程发生错误时,此工具将汇总错误详情(例如,时间、秩、主机、进程 ID、堆栈跟踪等)。在每个节点上,第一个错误(按时间戳)会被启发式地报告为“根本原因”错误。要将堆栈跟踪作为错误摘要输出的一部分,您必须按照下面的示例所示,修饰训练脚本中的主入口点函数。如果未修饰,则摘要将不包含异常的堆栈跟踪,仅包含退出码。有关 torchelastic 错误处理的详细信息,请参见:https://pytorch.ac.cn/docs/stable/elastic/errors.html

from torch.distributed.elastic.multiprocessing.errors import record


@record
def main():
    # do train
    pass


if __name__ == "__main__":
    main()

文档

查阅全面的 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源