• 文档 >
  • 在 XLA 设备上使用 PyTorch
快捷方式

在 XLA 设备上使用 PyTorch

PyTorch 通过 torch_xla 软件包在 TPU 等 XLA 设备上运行。本文档介绍如何在这些设备上运行模型。

创建 XLA Tensor

PyTorch/XLA 为 PyTorch 添加了一种新的 xla 设备类型。这种设备类型的工作方式与其他 PyTorch 设备类型类似。例如,下面是创建并打印 XLA tensor 的方法

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

这段代码应该看起来很熟悉。PyTorch/XLA 使用与普通 PyTorch 相同的接口,并有一些附加功能。导入 torch_xla 初始化 PyTorch/XLA,而 xm.xla_device() 返回当前的 XLA 设备。这可能是 CPU 或 TPU,具体取决于您的环境。

XLA Tensor 是 PyTorch Tensor

PyTorch 操作可以在 XLA tensor 上执行,就像在 CPU 或 CUDA tensor 上一样。

例如,XLA tensor 可以相加

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

或者进行矩阵乘法

print(t0.mm(t1))

或者与神经网络模块一起使用

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

与其他设备类型一样,XLA tensor 只能与同一设备上的其他 XLA tensor 一起使用。因此,像这样的代码

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

会抛出错误,因为 torch.nn.Linear 模块位于 CPU 上。

在 XLA 设备上运行模型

构建新的 PyTorch 网络或将现有网络转换为在 XLA 设备上运行,只需要几行 XLA 特定的代码。以下代码片段突出显示了在单个设备上以及使用 XLA 多进程在多个设备上运行时需要修改的代码行。

在单个 XLA 设备上运行

以下代码片段展示了在单个 XLA 设备上训练网络的过程

import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  optimizer.step()
  xm.mark_step()

此代码片段突出显示了将模型切换到 XLA 上运行是多么容易。模型定义、数据加载器、优化器和训练循环可以在任何设备上工作。唯一的 XLA 特定代码是几行用于获取 XLA 设备和标记步骤的代码。在每个训练迭代结束时调用 xm.mark_step() 会使 XLA 执行其当前图并更新模型参数。有关 XLA 如何创建图和运行操作的更多信息,请参阅XLA Tensor 深度解析

使用多进程在多个 XLA 设备上运行

PyTorch/XLA 使在多个 XLA 设备上运行以加速训练变得容易。以下代码片段展示了如何操作

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())

此多设备代码片段与之前的单个设备代码片段之间有三个区别。让我们一一介绍。

  • torch_xla.launch()

    • 创建每个运行 XLA 设备的进程。

    • 此函数是多线程 spawn 的包装器,也允许用户使用 torchrun 命令行运行脚本。每个进程只能访问分配给当前进程的设备。例如,在 TPU v4-8 上,将启动 4 个进程,每个进程拥有一个 TPU 设备。

    • 请注意,如果您在每个进程上打印 xm.xla_device(),您将在所有设备上看到 xla:0。这是因为每个进程只能看到一个设备。这并不意味着多进程没有正常工作。唯一的例外是在 TPU v2 和 TPU v3 上使用 PJRT 运行时,因为将有 #devices/2 个进程,并且每个进程将有 2 个线程(请查阅此文档了解更多详情)。

  • MpDeviceLoader

    • 将训练数据加载到每个设备上。

    • MpDeviceLoader 可以封装 torch dataloader。它可以将数据预加载到设备上,并将数据加载与设备执行重叠,以提高性能。

    • MpDeviceLoader 还会为您在每 batches_per_execution(默认为 1)个 yield 的批次后调用 xm.mark_step

  • xm.optimizer_step(optimizer)

    • 合并设备间的梯度并发出 XLA 设备步长计算。

    • 它基本上是 all_reduce_gradients + optimizer.step() + mark_step 的组合,并返回减少后的损失。

模型定义、优化器定义和训练循环保持不变。

注意: 需要重点注意的是,当使用多进程时,用户只能在 torch_xla.launch() 的目标函数内部(或调用栈中以 torch_xla.launch() 作为父函数的任何函数中)开始检索和访问 XLA 设备。

请参阅完整的多进程示例,了解更多关于使用多进程在多个 XLA 设备上训练网络的信息。

在 TPU Pod 上运行

不同加速器的多主机设置可能会有很大差异。本文档将讨论多主机训练中与设备无关的部分,并以 TPU + PJRT 运行时(目前在 1.13 和 2.x 版本中可用)为例进行说明。

在开始之前,请查看我们在此的用户指南,其中解释了一些 Google Cloud 基础知识,例如如何使用 gcloud 命令以及如何设置您的项目。您也可以在此查看所有 Cloud TPU 操作指南。本文档将重点介绍从 PyTorch/XLA 角度进行设置。

假设您有一个名为 train_mnist_xla.py 的文件,其中包含上面章节中的 mnist 示例。如果是单主机多设备训练,您可以 ssh 到 TPUVM 并运行类似命令

PJRT_DEVICE=TPU python3 train_mnist_xla.py

现在,为了在 TPU v4-16(它有 2 个主机,每个主机有 4 个 TPU 设备)上运行相同的模型,您需要: - 确保每个主机都可以访问训练脚本和训练数据。这通常通过使用 gcloud scp 命令或 gcloud ssh 命令将训练脚本复制到所有主机来完成。 - 同时在所有主机上运行相同的训练命令。

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"

上面的 gcloud ssh 命令将 ssh 到 TPUVM Pod 中的所有主机并同时运行相同的命令。

注意: 您需要在 TPUVM 虚拟机外部运行上述 gcloud 命令。

模型代码和训练脚本对于多进程训练和多主机训练是相同的。PyTorch/XLA 和底层基础设施将确保每个设备都了解全局拓扑结构以及每个设备的本地和全局序号。跨设备通信将在所有设备之间进行,而不是仅限于本地设备。

有关 PJRT 运行时以及如何在 Pod 上运行的更多详情,请参阅此文档。有关 PyTorch/XLA 和 TPU Pod 的更多信息,以及在 TPU Pod 上使用 fakedata 运行 resnet50 的完整指南,请参阅此指南

XLA Tensor 深度解析

使用 XLA tensor 和设备只需要更改几行代码。但是,即使 XLA tensor 的行为很像 CPU 和 CUDA tensor,它们的内部结构是不同的。本节描述了 XLA tensor 的独特之处。

XLA Tensor 是惰性的

CPU 和 CUDA tensor 会立即或即时地启动操作。而 XLA tensor 则是惰性的。它们会在图中记录操作,直到需要结果时才执行。这样延迟执行使得 XLA 可以对其进行优化。例如,多个单独操作的图可能会被融合成一个优化的操作。

惰性执行通常对调用者是不可见的。PyTorch/XLA 会自动构建图,将其发送到 XLA 设备,并在 XLA 设备和 CPU 之间复制数据时进行同步。在进行优化器步骤时插入屏障会显式同步 CPU 和 XLA 设备。有关我们惰性 tensor 设计的更多信息,您可以阅读此论文

内存布局

XLA tensor 的内部数据表示对用户来说是不透明的。它们不暴露其存储,并且总是显示为连续的,这与 CPU 和 CUDA tensor 不同。这允许 XLA 调整 tensor 的内存布局以获得更好的性能。

在 XLA Tensor 和 CPU 之间移动

XLA tensor 可以从 CPU 移动到 XLA 设备,也可以从 XLA 设备移动到 CPU。如果移动的是视图,则它所查看的数据也会复制到另一个设备,并且视图关系不会保留。换句话说,一旦数据被复制到另一个设备,它与之前的设备或其上的任何 tensor 都没有关系。同样,根据您的代码操作方式,理解和适应这种转换可能很重要。

保存和加载 XLA Tensor

XLA tensor 在保存之前应该移动到 CPU,如下面代码片段所示

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

这样您就可以将加载的 tensor 放在任何可用设备上,而不仅仅是初始化它们的设备。

根据上面关于将 XLA tensor 移动到 CPU 的注意事项,在使用视图时必须小心。建议您不要保存视图,而是在 tensor 加载并移动到目标设备后重新创建它们。

提供了实用 API,用于在保存数据时负责将其事先移动到 CPU

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xm.save(model.state_dict(), path)

在多个设备的情况下,上述 API 只会保存主设备序号 (0) 的数据。

在内存相对于模型参数大小有限的情况下,提供了一个 API,用于减少主机上的内存占用

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

此 API 将 XLA tensor 一个接一个地流式传输到 CPU,减少了使用的主机内存量,但需要匹配的加载 API 进行恢复

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)

直接保存 XLA tensor 是可能的,但不推荐。XLA tensor 总是被加载回保存它们的设备,如果该设备不可用,加载将会失败。PyTorch/XLA 与所有 PyTorch 一样,正在积极开发中,此行为将来可能会改变。

编译缓存

XLA 编译器将跟踪的 HLO 转换为可在设备上运行的可执行文件。编译可能非常耗时,并且在 HLO 在多次执行中不变的情况下,编译结果可以持久化到磁盘以便重用,这显著减少了开发迭代时间。

请注意,如果 HLO 在执行之间发生变化,仍然会进行重新编译。

这目前是一个实验性的可选加入 API,必须在执行任何计算之前激活。初始化通过 initialize_cache API 完成

import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)

这将在指定路径初始化一个持久化编译缓存。`readonly` 参数可用于控制 worker 是否可以写入缓存,这在将共享缓存挂载用于 SPMD 工作负载时非常有用。

如果您想在多进程训练中使用持久化编译缓存(使用 torch_xla.launchxmp.spawn),则应为不同的进程使用不同的路径。

def _mp_fn(index):
  # cache init needs to happens inside the mp_fn.
  xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
  ....

if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())

如果您无法访问 index,可以使用 xr.global_ordinal()。请查看此处的可运行示例。

进一步阅读

更多文档可在PyTorch/XLA 仓库中找到。在此提供了更多在 TPU 上运行网络的示例。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源