PyTorch 在 XLA 设备上¶
PyTorch 在 XLA 设备(如 TPU)上运行,使用 torch_xla 包。本文档介绍了如何在这些设备上运行模型。
创建 XLA 张量¶
PyTorch/XLA 向 PyTorch 添加了一种新的 xla
设备类型。此设备类型的工作方式与其他 PyTorch 设备类型完全相同。例如,以下是如何创建和打印 XLA 张量
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
此代码看起来应该很熟悉。PyTorch/XLA 使用与常规 PyTorch 相同的接口,并添加了一些内容。导入 torch_xla
会初始化 PyTorch/XLA,而 xm.xla_device()
会返回当前的 XLA 设备。这可能是 CPU 或 TPU,具体取决于您的环境。
XLA 张量是 PyTorch 张量¶
可以像 CPU 或 CUDA 张量一样对 XLA 张量执行 PyTorch 操作。
例如,可以将 XLA 张量相加
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
或矩阵相乘
print(t0.mm(t1))
或与神经网络模块一起使用
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
与其他设备类型一样,XLA 张量仅适用于同一设备上的其他 XLA 张量。因此,类似下面的代码
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
将引发错误,因为 torch.nn.Linear
模块位于 CPU 上。
在 XLA 设备上运行模型¶
构建新的 PyTorch 网络或转换现有网络以在 XLA 设备上运行,仅需几行特定于 XLA 的代码。以下代码片段重点介绍了在单个设备上以及使用 XLA 多处理在多个设备上运行时的这些行。
在单个 XLA 设备上运行¶
以下代码片段显示了在单个 XLA 设备上训练的网络
import torch_xla.core.xla_model as xm
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
此代码片段重点介绍了将模型切换到在 XLA 上运行是多么容易。模型定义、数据加载器、优化器和训练循环可以在任何设备上工作。唯一的 XLA 特定代码是几行获取 XLA 设备并标记步骤的代码。在每次训练迭代结束时调用 xm.mark_step()
会导致 XLA 执行其当前图并更新模型的参数。有关 XLA 如何创建图和运行操作的更多信息,请参阅 XLA 张量深入探讨。
使用多处理在多个 XLA 设备上运行¶
PyTorch/XLA 可以通过在多个 XLA 设备上运行来轻松加速训练。以下代码片段显示了如何操作
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
def _mp_fn(index):
device = xm.xla_device()
mp_device_loader = pl.MpDeviceLoader(train_loader, device)
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in mp_device_loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer)
if __name__ == '__main__':
torch_xla.launch(_mp_fn, args=())
此多设备代码片段与之前的单设备代码片段之间存在三个差异。让我们逐一介绍。
torch_xla.launch()
创建每个运行一个 XLA 设备的进程。
此函数是多线程 spawn 的包装器,允许用户使用 torchrun 命令行运行脚本。每个进程将只能访问分配给当前进程的设备。例如,在 TPU v4-8 上,将生成 4 个进程,每个进程将拥有一个 TPU 设备。
请注意,如果您在每个进程上打印
xm.xla_device()
,您将在所有设备上看到xla:0
。这是因为每个进程只能看到一个设备。这并不意味着多进程不起作用。唯一的执行是使用 TPU v2 和 TPU v3 上的 PJRT 运行时,因为将有#devices/2
个进程,并且每个进程将有 2 个线程(有关更多详细信息,请查看此 文档)。
MpDeviceLoader
将训练数据加载到每个设备上。
MpDeviceLoader
可以包装在 torch 数据加载器上。它可以将数据预加载到设备并使数据加载与设备执行重叠,从而提高性能。MpDeviceLoader
还会为您在每次 yieldbatches_per_execution
(默认为 1)批次时调用xm.mark_step
。
xm.optimizer_step(optimizer)
整合设备之间的梯度并发出 XLA 设备步骤计算。
它几乎是一个
all_reduce_gradients
+optimizer.step()
+mark_step
,并返回正在减少的损失。
模型定义、优化器定义和训练循环保持不变。
注意: 重要的是要注意,当使用多处理时,用户只能从
torch_xla.launch()
的目标函数(或任何在调用堆栈中将torch_xla.launch()
作为父级的函数)中开始检索和访问 XLA 设备。
有关使用多处理在多个 XLA 设备上训练网络的更多信息,请参阅 完整的 multiprocessing 示例。
在 TPU Pod 上运行¶
不同加速器的多主机设置可能非常不同。本文档将讨论多主机训练的设备无关位,并将 TPU + PJRT 运行时(当前在 1.13 和 2.x 版本上可用)用作示例。
在开始之前,请查看我们的用户指南 此处,其中将解释一些 Google Cloud 基础知识,例如如何使用 gcloud
命令以及如何设置项目。您还可以查看 此处 获取所有 Cloud TPU Howto。本文档将重点介绍 PyTorch/XLA 的设置视角。
假设您在 train_mnist_xla.py
中有上面部分中的 mnist 示例。如果是单主机多设备训练,您将 ssh 到 TPUVM 并运行如下命令
PJRT_DEVICE=TPU python3 train_mnist_xla.py
现在,为了在 TPU v4-16(具有 2 个主机,每个主机具有 4 个 TPU 设备)上运行相同的模型,您需要 - 确保每个主机都可以访问训练脚本和训练数据。这通常通过使用 gcloud scp
命令或 gcloud ssh
命令将训练脚本复制到所有主机来完成。 - 同时在所有主机上运行相同的训练命令。
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
上面的 gcloud ssh
命令将 ssh 连接到 TPUVM Pod 中的所有主机,并同时运行相同的命令。
注意: 您需要在 TPUVM vm 外部运行上面的
gcloud
命令。
模型代码和训练脚本对于多进程训练和多主机训练是相同的。PyTorch/XLA 和底层基础架构将确保每个设备都了解全局拓扑以及每个设备的本地和全局序号。跨设备通信将在所有设备而不是本地设备之间发生。
有关 PJRT 运行时以及如何在 pod 上运行它的更多详细信息,请参阅此 文档。有关 PyTorch/XLA 和 TPU pod 的更多信息,以及在 TPU pod 上使用 fakedata 运行 resnet50 的完整指南,请参阅此指南。
XLA 张量深入探讨¶
使用 XLA 张量和设备只需要更改几行代码。但是,即使 XLA 张量的行为与 CPU 和 CUDA 张量非常相似,但它们的内部结构却不同。本节介绍使 XLA 张量独一无二的原因。
XLA 张量是惰性的¶
CPU 和 CUDA 张量会立即或急切地启动操作。另一方面,XLA 张量是惰性的。它们将操作记录在图中,直到需要结果。像这样延迟执行可以让 XLA 对其进行优化。例如,可以将多个单独操作的图融合为单个优化操作。
惰性执行通常对调用者不可见。PyTorch/XLA 会自动构建图,将其发送到 XLA 设备,并在 XLA 设备和 CPU 之间复制数据时进行同步。在执行优化器步骤时插入屏障会显式同步 CPU 和 XLA 设备。有关我们的惰性张量设计的更多信息,您可以阅读 本文。
内存布局¶
XLA 张量的内部数据表示对用户是不透明的。它们不公开其存储,并且始终显示为连续的,这与 CPU 和 CUDA 张量不同。这允许 XLA 调整张量的内存布局以获得更好的性能。
在 CPU 和 XLA 设备之间移动 XLA 张量¶
XLA 张量可以从 CPU 移动到 XLA 设备,也可以从 XLA 设备移动到 CPU。如果移动视图,则其正在查看的数据也会复制到另一设备,并且视图关系不会保留。换句话说,一旦数据复制到另一设备,它就与其之前的设备或其上的任何张量没有任何关系。同样,根据您的代码运行方式,理解和适应此转换可能很重要。
保存和加载 XLA 张量¶
XLA 张量应在保存之前移动到 CPU,如下面的代码片段所示
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
这使您可以将加载的张量放在任何可用设备上,而不仅仅是初始化它们的设备上。
根据上面关于将 XLA 张量移动到 CPU 的说明,在使用视图时必须小心。建议您在加载张量并将其移动到目标设备后重新创建视图,而不是保存视图。
提供了一个实用程序 API,用于通过事先将其移动到 CPU 来保存数据
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
在多设备情况下,上面的 API 将仅保存主设备序号 (0) 的数据。
在内存与模型参数大小相比有限的情况下,提供了一个 API,用于减少主机上的内存占用
import torch_xla.utils.serialization as xser
xser.save(model.state_dict(), path)
此 API 一次将 XLA 张量流式传输到 CPU,从而减少了使用的主机内存量,但它需要匹配的加载 API 才能恢复
import torch_xla.utils.serialization as xser
state_dict = xser.load(path)
model.load_state_dict(state_dict)
可以直接保存 XLA 张量,但不建议这样做。XLA 张量始终加载回保存它们的设备,如果该设备不可用,则加载将失败。PyTorch/XLA 与所有 PyTorch 一样,正在积极开发中,此行为将来可能会更改。
编译缓存¶
XLA 编译器将跟踪的 HLO 转换为在设备上运行的可执行文件。编译可能很耗时,并且在 HLO 在执行之间没有更改的情况下,可以将编译结果持久保存到磁盘以供重用,从而显着减少开发迭代时间。
请注意,如果 HLO 在执行之间发生更改,则仍会发生重新编译。
这目前是一个实验性的选择加入 API,必须在执行任何计算之前激活。初始化通过 initialize_cache
API 完成
import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)
这将在指定路径初始化持久编译缓存。readonly
参数可用于控制工作进程是否能够写入缓存,这在共享缓存挂载用于 SPMD 工作负载时可能很有用。
如果您想在多进程训练(使用 torch_xla.launch
或 xmp.spawn
)中使用持久编译缓存,则应为不同的进程使用不同的路径。
def _mp_fn(index):
# cache init needs to happens inside the mp_fn.
xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
....
if __name__ == '__main__':
torch_xla.launch(_mp_fn, args=())
如果您无权访问 index
,则可以使用 xr.global_ordinal()
。在 此处 查看可运行的示例。
进一步阅读¶
其他文档可在 PyTorch/XLA 代码库中找到。有关在 TPU 上运行网络的更多示例,请访问 此处。