在 GPU 上运行 SPMD¶
PyTorch/XLA 支持在 NVIDIA GPU 上运行 SPMD(单节点或多节点)。训练/推理脚本与用于 TPU 的脚本相同,例如此 ResNet 脚本。要使用 SPMD 执行脚本,我们利用 torchrun
PJRT_DEVICE=CUDA \
torchrun \
--nnodes=${NUM_GPU_MACHINES} \
--node_rank=${RANK_OF_CURRENT_MACHINE} \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:<PORT>" \
training_or_inference_script_using_spmd.py
--nnodes
:要使用的 GPU 机器数量。--node_rank
:当前 GPU 机器的索引。值可以是 0, 1, …, ${NUMBER_GPU_VM}-1。--nproc_per_node
:由于 SPMD 要求,该值必须为 1。--rdzv_endpoint
:node_rank==0 的 GPU 机器的端点,格式为host:port
。host 将是内部 IP 地址。port
可以是机器上的任何可用端口。对于单节点训练/推理,可以省略此参数。
例如,如果您想在 2 台 GPU 机器上使用 SPMD 训练 ResNet 模型,则可以在第一台机器上运行以下脚本
XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" \
pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128
并在第二台机器上运行以下命令
XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" \
pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128
有关更多信息,请参阅 GPU 上 SPMD 支持 RFC。