快捷方式

PJRT 运行时

PyTorch/XLA 已从基于 TensorFlow 的 XRT 运行时迁移到 JAX 使用的 PJRT 运行时

如果您在使用 PJRT 时遇到错误,请在 GitHub 上提交一个带有 runtime 标签的问题。

PyTorch/XLA r2.1 中的新特性:

  • PJRT 在 PyTorch/XLA r2.1 中已稳定!

  • 公共运行时 API 已从 torch_xla.experimental.pjrt 迁移到 torch_xla.runtime

    • pjrt:// init 方法已重命名为 xla://,并通过 torch_xla.distributed.xla_backend 进行注册。

    • 为了兼容性,之前的 torch_xla.experimental.* 名称在此版本中仍然可用。

  • 使用 init_method='xla://' 时,现在支持 torchrun

  • 通过 PJRT C API 为 XPU 和 Neuron 提供了新的插件。

PyTorch/XLA r2.0 中的新特性:

  • 如果您不传入任何其他运行时配置,PJRT 将默认配置。如果您继续设置 XRT 配置 (XRT_TPU_CONFIG),此更改没有影响。

  • libtpu 中的新 TPU 运行时实现将性能提升高达 30%。

  • 新的 xm.rendezvous 实现,可扩展到数千个 TPU core

  • [实验性] 支持 TPU v2 和 v3 的 torch.distributed,包括 pjrt:// init_method

概览

  • 要使用 PJRT 预览版运行时,请将 PJRT_DEVICE 环境变量设置为 CPUTPUCUDA

  • 在 XRT 中,所有分布式工作负载都是多进程的,每个设备一个进程。在 PJRT 的 TPU v2 和 v3 上,工作负载是多进程且多线程的(4 个进程,每个进程 2 个线程),因此您的工作负载应该是线程安全的。有关更多信息,请参见 TPU v2/v3 上的多线程API 指南的多进程部分。需要记住的关键差异如下:

    • 要以线程安全的方式初始化模型,可以在初始化后通过 torch_xla.experimental.pjrt.broadcast_master_param 将参数广播到所有副本,或者从一个共同的检查点加载每个副本的参数。

    • 对于其他随机数生成,尽可能使用 torch.Generator。全局 torch RNG 不是线程安全的,即使您在所有副本上设置相同的 torch.manual_seed

    • 要使用 torch.distributed,请导入 torch_xla.experimental.pjrt_backend 并使用 xla:// init_method

    • 对于 GPU 和 TPU v4,这些步骤是可选的。

XRT 到 PJRT 的示例差异

import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
-  dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+  dist.init_process_group('xla', init_method='xla://')

  torch.manual_seed(42)
  model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
  model = DDP(model, gradient_as_bucket_view=True)

  loss_fn = nn.MSELoss()
  optimizer = optim.SGD(model.parameters(), lr=.001)

  for i in range(10):
    data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()

    optimizer.step()
    xm.mark_step()

  # Print mean parameters so we can confirm they're the same across replicas
  print([p.mean() for p in model.parameters()])

if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

  torch_xla.launch(_mp_fn)

优势

  • 简单的运行时配置:只需将 PJRT_DEVICE 设置为 TPUCPUCUDA,即可开始使用 XLA!或者,让 PJRT 根据您的环境自动选择设备。

  • 性能提升:减少 gRPC 开销意味着更快的端到端执行。在 TorchBench 2.0 上,我们在 TPU v4 上观察到训练时间提升了 >35%。

  • 简单的 pod 执行:只需将代码复制到每个 TPU worker,然后使用 gcloud compute tpus tpuvm ssh --worker=all 同时执行它们。

  • 更好的扩展性:消除了 XRT 对参数大小的限制,并支持高达 2048 个 TPU 芯片。

快速入门

要开始使用 PJRT 与 PyTorch/XLA,您只需设置 PJRT_DEVICE 环境变量。如果您正在使用 TPU v2 或 v3,请继续阅读以了解 TPU v2、v3 和 v4 之间的差异。

CPU

在安装了 PyTorch/XLA 的任何机器上,您可以像这样在 CPU 上运行我们的 MNIST 示例:

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data

TPU

要创建已安装 PyTorch/XLA r2.0 的新 TPU:

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

在 v4-8 上,您可以像这样运行我们的 ResNet50 示例:

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

默认情况下,PJRT 将使用所有 TPU 芯片。要仅使用一个 TPU 芯片,请配置 TPU_PROCESS_BOUNDSTPU_VISIBLE_CHIPS

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

Pods

在 TPU Pods 上,使用 gcloud 在每个 TPU 上并行运行您的命令:

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"

Docker

您也可以使用 Docker 在预装了 PyTorch/XLA 的容器中运行您的工作负载:

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

请注意,docker run 需要对主机具有特权访问权限 (--privileged),以便将 TPU 设备暴露给容器。目前,TPU pods 上的 Docker 仅支持主机网络 --net=host。有关更多信息,请参阅 Cloud TPU 文档

GPU

单节点 GPU 训练

要将 PJRT 与 GPU 一起使用,只需设置 PJRT_DEVICE=CUDA 并将 GPU_NUM_DEVICES 配置为主机上的设备数量。例如:

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

您也可以使用 torchrun 启动单节点多 GPU 训练。例如:

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在上述示例中,--nnodes 表示要使用的机器数量(物理机器或虚拟机)(由于我们进行单节点训练,因此为 1)。--nproc-per-node 表示要使用的 GPU 设备数量。

多节点 GPU 训练

请注意,此功能仅适用于 cuda 12+。与 PyTorch 使用多节点训练类似,您可以按如下方式运行命令:

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
  • --nnodes: 要使用的 GPU 机器数量。

  • --node_rank: 当前 GPU 机器的索引。值可以为 0, 1, ..., ${NUMBER_GPU_VM}-1。

  • --nproc_per_node: 当前机器上要使用的 GPU 设备数量。

  • --rdzv_endpoint: node_rank==0 的 GPU 机器的端点,格式为 host:porthost 将是内部 IP 地址。port 可以是机器上的任何可用端口。对于单节点训练/推理,此参数可以省略。

例如,如果您想在 2 台 GPU 机器上训练:machine_0 和 machine_1,在第一台 GPU 机器 machine_0 上运行:

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在第二台 GPU 机器上运行:

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

上面两个命令之间的区别在于 --node_rank,如果您想在每台机器上使用不同数量的 GPU 设备,也可能包括 --nproc_per_node。其余部分完全相同。有关 torchrun 的更多信息,请参阅此页面

与 XRT 的区别

尽管在大多数情况下,我们期望 PJRT 和 XRT 从最终用户的角度来看几乎可以互换使用(尤其是在 TPU v4 上),但仍有一些微妙的差异需要牢记。重要的是,XRT 是围绕 TPU Node 架构设计的,因此即使在 TPU VM 上,它也始终会生成一个客户端和一个服务器进程。因此,每批输入数据在通过网络发送时都需要额外的序列化和反序列化延迟。

PJRT 直接使用本地设备,无需中间服务器进程。在默认配置下,PJRT 将为每个 TPU 芯片创建一个进程,即每个 TPU 主机 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档

  • 对于受开销限制的工作负载,可能获得性能提升。

  • 在 XRT 下,服务器进程是唯一与 TPU 设备交互的进程,客户端进程无法直接访问 TPU 设备。当分析单主机 TPU(例如 v3-8 或 v4-8)时,您通常会看到 8 个设备跟踪(每个 TPU core 一个)。使用 PJRT,每个进程有一个芯片,该进程的分析将仅显示 2 个 TPU core。

    • 出于同样的原因,分析功能在 XRT 的 TPU Pods 上不起作用,因为服务器进程独立于用户的模型代码运行。PJRT 没有这个限制,因此可以在 TPU Pod 中对每个进程的 2 个 TPU core 进行分析。

  • PJRT 仅支持 TPU VM 架构,我们没有计划使用 PJRT 支持 TPU Node 架构。

  • PJRT 的运行时配置显著简化。xla_dist 不是运行 TPU Pod 工作负载所必需的。相反,将您的代码复制到每个 TPU 主机 ([gcloud compute tpus tpu-vm   scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)),然后在每个主机上并行运行代码(例如 [gcloud compute tpus tpu-vm   ssh --workers=all --command="PJRT_DEVICE=TPU python   run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))。

  • xm.rendezvous 已使用 XLA 原生集合通信重新实现,以增强在大型 TPU Pods 上的稳定性。有关更多详细信息,请参见下文。

TPU v2/v3 上的多线程

在 TPU v2 和 v3 上,分布式工作负载总是以多线程方式运行,因为每个 TPU core 都将两个 TPU core 作为设备暴露,并且一次只能有一个进程打开一个 TPU 芯片。在其默认配置中,xmp.spawn 会自动生成尽可能多的进程(每个 TPU 主机 4 个),并为每个进程创建两个线程(每个 TPU core 一个)。

注意:在 TPU v4 上,每个 TPU 芯片被表示为一个 PyTorch 设备,因此分布式工作负载将运行在 4 个进程上,每个进程只有一个线程。这与 XRT 的行为相同。

在大多数情况下,这不需要对您现有代码进行实质性更改。大多数情况下您需要进行的主要更改是模型初始化。因为 torch 的全局 RNG 在线程之间共享,即使您在每个副本中将 torch.manual_seed 设置为相同的值,不同线程和不同运行之间的结果也会有所不同。为了在副本之间获得一致的参数,您可以使用 torch_xla.experimental.pjrt.broadcast_master_param 将一个副本的参数广播到所有其他副本,或者从一个共同的检查点加载每个副本的参数。

xm.rendezvous 的更改

PyTorch/XLA r2.0 中的新内容

使用 XRT 时,worker 0 运行一个 mesh master 服务,所有 worker 上的所有进程都通过 gRPC 连接到该服务。在实践中,我们发现由于连接到 worker 0 的入站连接数量庞大,在拥有数千个芯片的 TPU pods 上运行单个 mesh master 进程是不可靠的。单个客户端进程超时可能导致故障并强制整个工作负载重启。

因此,我们使用 XLA 原生集合通信重新实现了 xm.rendezvous,这在大型 TPU pods 上更加稳定且经过充分测试。与 XRT 实现相比,这带来了两个新的限制:

  • 由于有效载荷必须成为 XLA 图的一部分,因此在数据传输之前和之后都会调用 xm.mark_step。在模型代码中间调用 xm.rendezvous 可能会强制进行不必要的编译。

  • 由于 XLA 不允许集合操作在部分 worker 上运行,所有 worker 都必须参与 rendezvous

如果您需要 xm.rendezvous 的旧行为(即在不改变 XLA 图的情况下通信数据和/或同步部分 worker),请考虑使用 torch.distributed.barrier torch.distributed.all_gather_object gloo 进程组。如果您同时使用 xla torch.distributed 后端,您可以使用 torch.new_group 来创建 gloo 子组。请参见 PyTorch 文档中的 此示例。请记住以下限制:

  • 在 TPU v2/v3 上,torch.distributed 不完全支持。仅实现了 xla 后端的部分操作,并且 gloo 在多线程上下文中可能无法按预期工作。

  • 在我们的实验中,gloo 在扩展到数千个 TPU 芯片时表现不佳,因此在大规模使用时,这种替代方案的可靠性不如使用 PJRT 的 xm.rendezvous

PJRT 和 torch.distributed

PyTorch/XLA r2.0 中的新内容

将 PJRT 与 torch.distributed[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md) 一起使用时,我们强烈建议使用新的 xla:// init_method,它通过查询运行时自动查找副本 ID、世界大小和 master IP。例如:

import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  torch_xla.launch(_all_gather)

注意:尽管在 TPU v4 上不是必需的,但仍建议使用 xla:// init_method。如果使用 env://MASTER_ADDR 必须设置为拥有设备 0 的 IP 主机,这不总是 worker 0。xla:// init_method 会自动查找此 IP。

注意:对于 TPU v2/v3,您仍然需要导入 torch_xla.experimental.pjrt_backend,因为 torch.distributed 中对 TPU v2/v3 的支持仍处于实验阶段。

有关在 PyTorch/XLA 上使用 DistributedDataParallel 的更多信息,请参阅 TPU V4 上的 ddp.md。有关结合使用 DDP 和 PJRT 的示例,请在 TPU 上运行以下 示例脚本

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1

性能

与 XRT 相比,TorchBench 显示 PJRT 在各项任务的平均训练时间上都有所改进,在 TPU v4-8 上平均提升超过 35%。具体提升效果因任务和模型类型而异,范围在 0% 到 175% 之间。以下图表显示了按任务细分的结果:

PJRT vs XRT

新 TPU 运行时

PyTorch/XLA r2.0 中的新内容

PyTorch/XLA r2.0 版本引入了对 PJRT Plugin API 的支持,用于访问 libtpu 中新的基于 TFRT 的 TPU 运行时。当设置 PJRT_DEVICE=TPU 时,这现在是默认运行时。1.13 中使用的基于 StreamExecutor 的旧版 TPU 运行时在 2.0 版本中仍可通过设置 PJRT_DEVICE=TPU_LEGACY 使用,但将在未来版本中移除。如果您遇到仅在 TPU 而不在 TPU_LEGACY 上发生的问题,请在 GitHub 上提交一个问题。

在大多数情况下,我们期望两种运行时的性能相似,但在某些情况下,新运行时可能会快达 30%。以下图表显示了按任务细分的结果:

TFRT vs StreamExecutor

注意:此图表中显示的提升也包含在 PJRT 与 XRT 的比较中。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源