快捷键

PJRT 运行时

PyTorch/XLA 已从基于 TensorFlow 的 XRT 运行时迁移到 PJRT 运行时,该运行时由 JAX 使用。

如果遇到 PJRT 的错误,请在 GitHub 上提交问题,并添加 runtime 标签。

PyTorch/XLA r2.1 中的新功能:

  • PJRT 在 PyTorch/XLA r2.1 中稳定!

  • 公共运行时 API 已从 torch_xla.experimental.pjrt 移动到 torch_xla.runtime

    • pjrt:// 初始化方法已重命名为 xla://,它由 torch_xla.distributed.xla_backend 注册。

    • 为了兼容性,此版本中仍然可以使用以前的 torch_xla.experimental.* 名称。

  • torchrun 现在支持使用 init_method='xla://'

  • 通过 PJRT C API 为 XPU 和 Neuron 提供新的插件。

PyTorch/XLA r2.0 中的新功能:

  • 如果您不传入任何其他运行时配置,PJRT 将默认配置。如果您继续设置 XRT 配置 (XRT_TPU_CONFIG),则此更改不会产生任何影响

  • libtpu 中的新 TPU 运行时实现将性能提高了高达 30%。

  • 新的 xm.rendezvous 实现可以扩展到数千个 TPU 内核

  • [实验性] torch.distributed 支持 TPU v2 和 v3,包括 pjrt:// init_method

TL;DR

  • 要使用 PJRT 预览运行时,请将 PJRT_DEVICE 环境变量设置为 CPUTPUCUDA

  • 在 XRT 中,所有分布式工作负载都是多进程的,每个设备一个进程。在 PJRT 中的 TPU v2 和 v3 上,工作负载是多进程和多线程的(4 个进程,每个进程 2 个线程),因此您的工作负载应是线程安全的。有关更多信息,请参阅 TPU v2/v3 上的多线程API 指南中的多进程部分。请记住关键区别

    • 要以线程安全的方式初始化模型,请在初始化后将参数广播到所有副本 (torch_xla.experimental.pjrt.broadcast_master_param),或者从公共检查点加载每个副本的参数。

    • 对于其他随机数生成,请尽可能使用 torch.Generator。全局 torch RNG 不是线程安全的,即使您在所有副本中设置了相同的 torch.manual_seed

    • 要使用 torch.distributed,请导入 torch_xla.experimental.pjrt_backend 并使用 xla:// init_method

    • 对于 GPU 和 TPU v4,这些步骤是可选的。

从 XRT 到 PJRT 的示例差异

 import os

 import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.optim as optim
 import torch.distributed as dist
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.distributed.xla_backend
 import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr


 def _mp_fn(index):
   device = xm.xla_device()
-  dist.init_process_group('xla', rank=xm.get_ordinal(), world_size=xm.xrt_world_size())
+  dist.init_process_group('xla', init_method='xla://')

   torch.manual_seed(42)
   model = nn.Linear(128, 10).to(device)

+  # Optional for TPU v4 and GPU
+  xm.broadcast_master_param(model)
   model = DDP(model, gradient_as_bucket_view=True)

   loss_fn = nn.MSELoss()
   optimizer = optim.SGD(model.parameters(), lr=.001)

   for i in range(10):
     data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()

     optimizer.step()
     xm.mark_step()

   # Print mean parameters so we can confirm they're the same across replicas
   print([p.mean() for p in model.parameters()])

 if __name__ == '__main__':
-  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'

+  # Recommended: set PJRT_DEVICE to your local device type
+  os.environ['PJRT_DEVICE'] = 'TPU'

   xmp.spawn(_mp_fn)

优势

  • 简单的运行时配置:只需将 PJRT_DEVICE 设置为 TPUCPUCUDA 即可开始使用 XLA!或者,让 PJRT 根据您的环境自动选择设备。

  • 性能提升:gRPC 开销减少,意味着更快的端到端执行。在 TorchBench 2.0 上,我们观察到 TPU v4 上的训练时间提高了 >35%。

  • 轻松执行 pod:只需将您的代码复制到每个 TPU 工作程序,并使用 gcloud compute tpus tpuvm ssh --worker=all 并行执行它们。

  • 更好的可扩展性:消除了 XRT 对参数大小的限制,并支持高达 2048 个 TPU 芯片。

快速入门

要开始使用 PyTorch/XLA 中的 PJRT,您只需要设置 PJRT_DEVICE 环境变量。如果您在 TPU v2 或 v3 上工作,请继续阅读以了解 TPU v2 和 v3 与 v4 之间的差异。

CPU

在安装了 PyTorch/XLA 的任何机器上,您可以像这样在 CPU 上运行我们的 MNIST 示例

PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data

TPU

要使用 PyTorch/XLA r2.0 创建新的 TPU

gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT

在 v4-8 上,您可以像这样运行我们的 ResNet50 示例

git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

默认情况下,PJRT 将使用所有 TPU 芯片。要仅使用一个 TPU 芯片,请配置 TPU_PROCESS_BOUNDSTPU_VISIBLE_CHIPS

TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

Pods

在 TPU Pod 上,使用 gcloud 并行运行每个 TPU 上的命令

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"

Docker

您还可以使用 Docker 在预安装了 PyTorch/XLA 的容器中运行您的工作负载

export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"

请注意,docker run 需要对主机拥有特权访问权限 (--privileged) 才能将 TPU 设备公开给容器。目前,TPU pod 上的 Docker 仅支持主机网络 --net=host。有关更多信息,请参阅 Cloud TPU 文档

GPU

单节点 GPU 训练

要使用 PJRT 中的 GPU,只需将 PJRT_DEVICE=CUDA 设置为 GPU_NUM_DEVICES,并将配置设置为主机上的设备数量。例如

PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

您还可以使用 torchrun 启动单节点多 GPU 训练。例如,

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在上面的示例中,--nnodes 表示要使用的机器(物理机或虚拟机)数量(由于我们进行的是单节点训练,因此为 1)。--nproc-per-node 表示要使用的 GPU 设备数量。

多节点 GPU 训练

请注意,此功能仅适用于 cuda 12+。与 PyTorch 如何使用多节点训练类似,您可以像下面这样运行命令

PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
  • --nnodes: 要使用的 GPU 机器数量。

  • --node_rank: 当前 GPU 机器的索引。该值可以是 0、1、…、${NUMBER_GPU_VM}-1。

  • --nproc_per_node: 要在当前机器上使用的 GPU 设备数量。

  • –rdzv_endpoint: node_rank==0 的 GPU 机器的端点,格式为 host:port`. The``host 内部 IP 地址。 port` 可以是机器上的任何可用端口。对于单节点训练/推理,可以省略此参数。

例如,如果您想在两台 GPU 机器上进行训练:machine_0 和 machine_1,在第一台 GPU 机器 machine_0 上运行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

在第二台 GPU 机器上运行

# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

上面两个命令之间的区别是 --node_rank 和可能还有 --nproc_per_node,如果您想在每台机器上使用不同数量的 GPU 设备。其他所有内容都相同。有关 torchrun 的更多信息,请参阅此 页面

与 XRT 的区别

尽管在大多数情况下,我们预计 PJRT 和 XRT 从最终用户的角度来看可以互换使用(尤其是在 TPU v4 上),但仍有一些细微差别需要牢记。重要的是,XRT 是围绕 TPU 节点架构设计的,因此它始终会生成一个客户端和一个服务器进程,即使在 TPU VM 上也是如此。因此,每批输入都会因为通过网络序列化和反序列化数据而产生额外的延迟。

PJRT 直接使用本地设备,无需中间服务器进程。在默认配置中,PJRT 将为每个 TPU 芯片创建一个进程,或为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档

  • 对于受 . 产生的开销限制的工作负载,可以提高性能。

  • 在 XRT 下,服务器进程是唯一与 TPU 设备交互的进程,客户端进程无法直接访问 TPU 设备。当对单主机 TPU(例如 v3-8 或 v4-8)进行性能分析时,通常会看到 8 个设备跟踪(每个 TPU 内核一个)。使用 PJRT 时,每个进程都拥有一个芯片,并且来自该进程的性能分析将只显示 2 个 TPU 内核。

    • 出于同样的原因,在 TPU Pod 上对 XRT 进行性能分析是行不通的,因为服务器进程独立于用户的模型代码运行。PJRT 没有这种限制,因此可以对 TPU Pod 中每个进程的 2 个 TPU 内核进行性能分析。

  • PJRT 仅支持 TPU VM 架构,我们没有计划使用 PJRT 支持 TPU 节点架构。

  • 使用 PJRT 时,运行时配置要简单得多。xla_dist 不需要运行 TPU Pod 工作负载。相反,将您的代码复制到每个 TPU 主机上([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp))并在每个主机上并行运行代码(例如 [gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh)

  • xm.rendezvous 已使用 XLA 本地集合通信重新实现,以提高大型 TPU Pod 上的稳定性。有关更多详细信息,请参见以下内容。

TPU v2/v3 上的多线程

在 TPU v2 和 v3 上,分布式工作负载始终以多线程方式运行,因为每个 TPU 内核将两个 TPU 内核公开为设备,并且一次只能有一个进程打开一个 TPU 芯片。在其默认配置中,xmp.spawn 会自动生成尽可能多的进程(每个 TPU 主机 4 个),并为每个进程创建两个线程(每个 TPU 内核一个)。

注意:在 TPU v4 上,每个 TPU 芯片都被表示为一个 PyTorch 设备,因此分布式工作负载将在 4 个进程之间运行,每个进程只有一个线程。这与 XRT 的行为相同。

在大多数情况下,这不需要对现有代码进行重大更改。在大多数情况下,您需要进行的主要更改是模型初始化。由于 torch 的全局 RNG 在线程之间共享,即使在每个副本中将 torch.manual_seed 设置为相同的值,结果在线程之间和运行之间也会有所不同。为了使副本之间的参数一致,请使用 torch_xla.experimental.pjrt.broadcast_master_param 将一个副本的参数广播到所有其他副本,或从一个公共检查点加载每个副本的参数。

对 xm.rendezvous 的更改

PyTorch/XLA r2.0 中的新增功能

使用 XRT 时,工作程序 0 会运行一个网格主服务,并且所有工作程序上的所有进程都会通过 gRPC 连接到该服务。在实践中,我们发现,由于连接到工作程序 0 的入站连接数量过多,在一个拥有数千个芯片的 TPU Pod 上运行一个网格主进程是不稳定的。一个客户端进程超时可能会导致故障,并迫使整个工作负载重新启动。

因此,我们使用 XLA 本地集合通信重新实现了 xm.rendezvous,它在大型 TPU Pod 上更加稳定且经过充分测试。与 XRT 实现相比,这带来了两个新的约束

  • 由于有效负载必须成为 XLA 图的一部分,因此在数据传输之前和之后都会调用 xm.mark_step。在模型代码中间调用 xm.rendezvous 可能会强制执行不必要的编译。

  • 由于 XLA 不允许在部分工作程序上运行集合操作,因此所有工作程序都必须参与 rendezvous

如果您需要 xm.rendezvous 的旧行为(即在不更改 XLA 图的情况下通信数据和/或同步部分工作程序),请考虑使用 ``torch.distributed.barrier` <https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.barrier>`_ 或 ``torch.distributed.all_gather_object` <https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.all_gather_object>`_,以及 gloo 进程组。如果您还使用的是 xla torch.distributed 后端,则可以使用 torch.new_group 创建一个 gloo 子组。请参阅 PyTorch 文档中的 此示例。请牢记以下约束

  • torch.distributed 在 TPU v2/v3 上并不完全支持。仅实现了一部分具有 xla 后端的操作,并且 gloo 在多线程环境中可能无法按预期工作。

  • 在我们的实验中,gloo 在数千个 TPU 芯片上扩展性不佳,因此预计这种替代方案的可靠性不如在大型规模上使用 PJRT 的 xm.rendezvous

PJRT 和 torch.distributed

PyTorch/XLA r2.0 中的新增功能

当将 PJRT 与 torch.distributed[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md) 一起使用时,我们强烈建议使用新的 xla:// init_method,它会通过查询运行时自动找到副本 ID、世界大小和主 IP。例如

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt

# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend

def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)

  xm.mark_step()
  print(output)

if __name__ == '__main__':
  xmp.spawn(_all_gather)

注意:虽然在 TPU v4 上不需要 xla:// init_method,但我们仍然建议使用它。如果您使用的是 env://,则必须将 MASTER_ADDR 设置为拥有设备 0 的 IP 主机,该主机并非总是工作程序 0。xla:// init_method 会自动查找此 IP。

注意:对于 TPU v2/v3,您仍然需要导入 torch_xla.experimental.pjrt_backend,因为 torch.distributed 中的 TPU v2/v3 支持仍在试验阶段。

有关在 PyTorch/XLA 上使用 DistributedDataParallel 的更多信息,请参阅 TPU V4 上的 ``ddp.md` <./ddp.md>`_。要运行使用 DDP 和 PJRT 的示例,请在 TPU 上运行以下 示例脚本

PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1

性能

TorchBench 显示,与 XRT 相比,使用 PJRT 时,跨任务的平均训练时间有所改善,在 TPU v4-8 上平均提高了 35% 以上。这种优势因任务和模型类型而异,从 0% 到 175% 不等。以下图表按任务列出了具体情况

PJRT vs XRT

新的 TPU 运行时

PyTorch/XLA r2.0 中的新增功能

PyTorch/XLA r2.0 版本引入了对 PJRT 插件 API 的支持,该 API 用于访问 libtpu 中基于 TFRT 的新 TPU 运行时。现在,当设置 PJRT_DEVICE=TPU 时,它将成为默认运行时。2.0 版本中仍将使用 PJRT_DEVICE=TPU_LEGACY 提供 1.13 中使用的基于 StreamExecutor 的传统 TPU 运行时,但它将在未来的版本中删除。如果您遇到仅在 TPU 上发生而不在 TPU_LEGACY 上发生的错误,请在 GitHub 上提交问题。

在大多数情况下,我们预计这两个运行时的性能会相似,但在某些情况下,新运行时的速度可能会快 30% 以上。以下图表按任务列出了具体情况

TFRT vs StreamExecutor

注意:此图表中显示的改进也包含在 PJRT 与 XRT 的比较中。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源