快捷方式

PJRT 运行时

PyTorch/XLA 已从基于 TensorFlow 的 XRT 运行时迁移到 PJRT 运行时JAX 使用了该运行时。

如果您遇到 PJRT 的错误,请在 GitHub 上提交问题,并添加 runtime 标签。

PyTorch/XLA r2.1 中的新功能:

  • PJRT 在 PyTorch/XLA r2.1 中是稳定的!

  • 公共运行时 API 已从 torch_xla.experimental.pjrt 移动到 torch_xla.runtime

    • pjrt:// 初始化方法已重命名为 xla://,并且由 torch_xla.distributed.xla_backend 注册。

    • 以前的 torch_xla.experimental.* 名称在此版本中仍然可用,以实现兼容性。

  • 使用 init_method='xla://' 时,现在支持 torchrun

  • 通过 PJRT C API 为 XPU 和 Neuron 提供新的插件。

PyTorch/XLA r2.0 中的新功能:

  • 如果您不传入任何其他运行时配置,PJRT 将默认配置。如果您继续设置 XRT 配置 (XRT_TPU_CONFIG),则此更改没有影响

  • libtpu 中的新 TPU 运行时实现将性能提高了高达 30%。

  • 新的 xm.rendezvous 实现可扩展到数千个 TPU 核心

  • [实验性] torch.distributed 支持 TPU v2 和 v3,包括 pjrt:// init_method

TL;DR

  • 要使用 PJRT 预览运行时,请将 PJRT_DEVICE 环境变量设置为 CPUTPUCUDA

  • 在 XRT 中,所有分布式工作负载都是多进程的,每个设备一个进程。在 PJRT 中的 TPU v2 和 v3 上,工作负载是多进程和多线程的(4 个进程,每个进程 2 个线程),因此您的工作负载应该是线程安全的。请参阅 TPU v2/v3 上的多线程API 指南的多进程部分 以获取更多信息。需要记住的关键区别

    • 要以线程安全的方式初始化模型,可以在初始化后跨副本广播参数 (torch_xla.experimental.pjrt.broadcast_master_param) 或从公共检查点加载每个副本的参数。

    • 对于其他随机数生成,请尽可能使用 torch.Generator。即使您在所有副本中设置相同的 torch.manual_seed,全局 torch RNG 也不是线程安全的。

    • 要使用 torch.distributed,请导入 torch_xla.experimental.pjrt_backend 并使用 xla:// init_method

    • 这些步骤对于 GPU 和 TPU v4 是可选的。

从 XRT 到 PJRT 的示例差异

import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
-  dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+  dist.init_process_group('xla', init_method='xla://')

  torch.manual_seed(42)
  model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
  model = DDP(model, gradient_as_bucket_view=True)

  loss_fn = nn.MSELoss()
  optimizer = optim.SGD(model.parameters(), lr=.001)

  for i in range(10):
    data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()

    optimizer.step()
    xm.mark_step()

  # Print mean parameters so we can confirm they're the same across replicas
  print([p.mean() for p in model.parameters()])

if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

  torch_xla.launch(_mp_fn)

优势

  • 简单的运行时配置:只需将 PJRT_DEVICE 设置为 TPUCPUCUDA,然后开始使用 XLA!或者,让 PJRT 根据您的环境自动选择设备。

  • 性能提升:减少 gRPC 的开销意味着更快的端到端执行。在 TorchBench 2.0 上,我们观察到 TPU v4 上的训练时间提高了 >35%。

  • 轻松的 Pod 执行:只需将您的代码复制到每个 TPU 工作器,并使用 gcloud compute tpus tpuvm ssh --worker=all 同时执行它们。

  • 更好的扩展性:消除了 XRT 对参数大小的限制,并支持多达 2048 个 TPU 芯片。

快速入门

要开始将 PJRT 与 PyTorch/XLA 一起使用,您只需设置 PJRT_DEVICE 环境变量。如果您正在使用 TPU v2 或 v3,请继续阅读以了解 TPU v2 和 v3 以及 v4 之间的差异。

CPU

在安装了 PyTorch/XLA 的任何机器上,您都可以像这样在 CPU 上运行我们的 MNIST 示例

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data

TPU

要创建安装了 PyTorch/XLA r2.0 的新 TPU

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

在 v4-8 上,您可以像这样运行我们的 ResNet50 示例

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

默认情况下,PJRT 将使用所有 TPU 芯片。要仅使用一个 TPU 芯片,请配置 TPU_PROCESS_BOUNDSTPU_VISIBLE_CHIPS

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

Pod

在 TPU Pod 上,使用 gcloud 在每个 TPU 上并行运行您的命令

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"

Docker

您还可以使用 Docker 在预装了 PyTorch/XLA 的容器中运行您的工作负载

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

请注意,docker run 需要对主机 (--privileged) 的特权访问权限,才能将 TPU 设备暴露给容器。目前,TPU pod 上的 Docker 仅支持主机网络 --net=host。有关更多信息,请参阅 Cloud TPU 文档

GPU

单节点 GPU 训练

要将 GPU 与 PJRT 一起使用,只需设置 PJRT_DEVICE=CUDA 并配置 GPU_NUM_DEVICES 为主机上的设备数量。例如

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

您还可以使用 torchrun 来启动单节点多 GPU 训练。例如,

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在上面的示例中,--nnodes 表示要使用多少台机器(物理机或虚拟机)(由于我们进行单节点训练,因此为 1)。--nproc-per-node 表示要使用多少个 GPU 设备。

多节点 GPU 训练

请注意,此功能仅适用于 cuda 12+。与 PyTorch 使用多节点训练的方式类似,您可以按如下方式运行命令

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
  • --nnodes:要使用的 GPU 机器数量。

  • --node_rank:当前 GPU 机器的索引。该值可以是 0, 1, …, ${NUMBER_GPU_VM}-1。

  • --nproc_per_node:当前机器上要使用的 GPU 设备数量。

  • --rdzv_endpoint:node_rank==0 的 GPU 机器的端点,格式为 host:porthost 将是内部 IP 地址。port 可以是机器上的任何可用端口。对于单节点训练/推理,可以省略此参数。

例如,如果您想在 2 台 GPU 机器上进行训练:machine_0 和 machine_1,在第一台 GPU 机器 machine_0 上,运行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在第二台 GPU 机器上,运行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

以上 2 个命令之间的区别在于 --node_rank 以及可能的 --nproc_per_node(如果您想在每台机器上使用不同数量的 GPU 设备)。所有其余部分都是相同的。有关 torchrun 的更多信息,请参阅此页面

与 XRT 的差异

虽然在大多数情况下,我们预计 PJRT 和 XRT 在最终用户的角度来看可以互换使用(尤其是在 TPU v4 上),但仍有一些细微的差异需要牢记。重要的是,XRT 是围绕 TPU 节点架构设计的,因此即使在 TPU VM 上,它也始终会生成客户端和服务器进程。因此,每批输入都会因序列化和反序列化数据以通过网络发送数据而产生额外的延迟。

PJRT 直接使用本地设备,没有中间服务器进程。在默认配置中,PJRT 将为每个 TPU 芯片创建一个进程,或为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档

  • 对于受开销限制的工作负载,性能提升是可能的。

  • 在 XRT 下,服务器进程是唯一与 TPU 设备交互的进程,客户端进程无法直接访问 TPU 设备。在分析单主机 TPU(例如 v3-8 或 v4-8)时,您通常会看到 8 个设备跟踪(每个 TPU 核心一个)。使用 PJRT,每个进程都有一个芯片,并且来自该进程的配置文件将仅显示 2 个 TPU 核心。

    • 出于同样的原因,分析在带有 XRT 的 TPU Pod 上不起作用,因为服务器进程独立于用户的模型代码运行。PJRT 没有该限制,因此可以在 TPU Pod 中为每个进程分析 2 个 TPU 核心。

  • PJRT 仅支持 TPU VM 架构,我们没有计划在 PJRT 中支持 TPU 节点架构。

  • 使用 PJRT,运行时配置明显更简单。xla_dist 不是运行 TPU Pod 工作负载所必需的。相反,将您的代码复制到每个 TPU 主机 ([gcloud compute tpus tpu-vm   scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)) 并在每个主机上并行运行代码 (例如 [gcloud compute tpus tpu-vm   ssh --workers=all --command="PJRT_DEVICE=TPU python   run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))

  • xm.rendezvous 已使用 XLA 原生集合通信重新实现,以增强大型 TPU pod 的稳定性。有关更多详细信息,请参见下文。

TPU v2/v3 上的多线程

在 TPU v2 和 v3 上,分布式工作负载始终以多线程方式运行,因为每个 TPU 核心将两个 TPU 核心作为设备公开,并且一次只能有一个进程打开一个 TPU 芯片。在其默认配置中,xmp.spawn 会自动生成尽可能多的进程(每个 TPU 主机 4 个),并为每个进程创建两个线程(每个 TPU 核心一个)。

注意:在 TPU v4 上,每个 TPU 芯片都表示为一个 PyTorch 设备,因此分布式工作负载将在 4 个进程之间运行,每个进程只有一个线程。这与 XRT 的行为相同。

在大多数情况下,这不需要对现有代码进行重大更改。在大多数情况下,您必须进行的主要更改是模型初始化。由于 torch 的全局 RNG 在线程之间共享,因此即使您在每个副本中将 torch.manual_seed 设置为相同的值,线程之间和运行之间的结果也会有所不同。要在副本之间获得一致的参数,请使用 torch_xla.experimental.pjrt.broadcast_master_param 将一个副本的参数广播到所有其他副本,或从公共检查点加载每个副本的参数。

xm.rendezvous 的更改

PyTorch/XLA r2.0 中的新功能

使用 XRT,工作器 0 运行网格主服务,并且所有工作器上的所有进程都通过 gRPC 连接到该服务。在实践中,我们发现由于到工作器 0 的入站连接数量,在具有数千个芯片的 TPU pod 上运行单个网格主进程是不可靠的。单个客户端进程超时可能会导致故障并强制整个工作负载重新启动。

因此,我们已使用原生 XLA 集合通信重新实现了 xm.rendezvous,这在大型 TPU pod 上更加稳定且经过充分测试。与 XRT 实现相比,这施加了两个新的约束

  • 由于有效负载必须成为 XLA 图的一部分,因此在数据传输之前和之后都会调用 xm.mark_step。在模型代码中间调用 xm.rendezvous 可能会强制进行不必要的编译。

  • 由于 XLA 不允许集合操作在工作器的子集上运行,因此所有工作器都必须参与 rendezvous

如果您需要 xm.rendezvous 的旧行为(即,在不更改 XLA 图和/或同步工作器子集的情况下通信数据),请考虑使用 `torch.distributed.barrier <https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.barrier&gt;[__ 或 ]{.title-ref}torch.distributed.all_gather_object <https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.all_gather_object&gt;[__ 与 ]{.title-ref}[gloo]{.title-ref}[ 进程组。如果您还使用 ]{.title-ref}[xla]{.title-ref}[ ]{.title-ref}[torch.distributed]{.title-ref}[ 后端,则可以使用 ]{.title-ref}[torch.new*group]{.title-ref}[ 创建 ]{.title-ref}[gloo]{.title-ref}[ 子组。请参阅 `此示例 https://pytorch.ac.cn/docs/stable/distributed.html#monitored-barrier]{.title-ref}*_,来自 PyTorch 文档。请记住以下约束

  • torch.distributed 在 TPU v2/v3 上未完全支持。仅实现了具有 xla 后端的运算子集,并且 gloo 可能在多线程上下文中无法按预期工作。

  • 在我们的实验中,gloo 无法很好地扩展到数千个 TPU 芯片,因此预计此替代方案在大型规模下不如使用带有 PJRT 的 xm.rendezvous 可靠。

PJRT 和 torch.distributed

PyTorch/XLA r2.0 中的新功能

将 PJRT 与 torch.distributed[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md) 一起使用时,我们强烈建议使用新的 xla:// init_method,它通过查询运行时自动查找副本 ID、世界大小和主 IP。例如

import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  torch_xla.launch(_all_gather)

注意:虽然在 TPU v4 上不需要 xla:// init_method,但仍然建议使用。如果您使用 env://,则 MASTER_ADDR 必须设置为具有设备 0 的 IP 主机,而该主机总是工作器 0。xla:// init_method 会自动查找此 IP。

注意:对于 TPU v2/v3,您仍然需要导入 torch_xla.experimental.pjrt_backend,因为 torch.distributed 中的 TPU v2/v3 支持仍处于实验阶段。

有关在 PyTorch/XLA 上使用 DistributedDataParallel 的更多信息,请参阅 TPU V4 上的 ddp.md。有关将 DDP 和 PJRT 一起使用的示例,请在 TPU 上运行以下 示例脚本

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1

性能

TorchBench 显示,与 XRT 相比,PJRT 在各项任务中的平均训练时间有所改善,TPU v4-8 上的平均改进超过 35%。收益因任务和模型类型而异,范围从 0% 到 175%。下图显示了按任务的细分

PJRT vs XRT

新的 TPU 运行时

PyTorch/XLA r2.0 中的新功能

PyTorch/XLA r2.0 版本引入了对 PJRT 插件 API 的支持,该 API 用于访问 libtpu 中新的基于 TFRT 的 TPU 运行时。当设置 PJRT_DEVICE=TPU 时,这现在是默认运行时。在 1.13 中使用的基于遗留 StreamExecutor 的 TPU 运行时仍然在 2.0 版本中可通过 PJRT_DEVICE=TPU_LEGACY 获得,但将在未来的版本中删除。如果您遇到仅在 TPU 上发生而不是在 TPU_LEGACY 上发生的问题,请在 GitHub 上提交问题。

在大多数情况下,我们预计两个运行时之间的性能相似,但在某些情况下,新运行时可能会快 30%。下图显示了按任务的细分

TFRT vs StreamExecutor

注意:此图表中显示的改进也包含在 PJRT 与 XRT 的比较中。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源