XLA 设备上的 PyTorch¶
PyTorch 可以使用 torch_xla 软件包 在 XLA 设备(如 TPU)上运行。本文档介绍了如何在这些设备上运行模型。
创建 XLA 张量¶
PyTorch/XLA 为 PyTorch 添加了一个新的 xla
设备类型。此设备类型的工作方式与其他 PyTorch 设备类型相同。例如,以下是创建和打印 XLA 张量的方法
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
此代码应该看起来很熟悉。PyTorch/XLA 使用与常规 PyTorch 相同的接口,并进行了一些添加。导入 torch_xla
会初始化 PyTorch/XLA,而 xm.xla_device()
会返回当前的 XLA 设备。根据您的环境,这可能是 CPU 或 TPU。
XLA 张量是 PyTorch 张量¶
PyTorch 操作可以在 XLA 张量上执行,就像 CPU 或 CUDA 张量一样。
例如,可以将 XLA 张量相加
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
或进行矩阵乘法
print(t0.mm(t1))
或与神经网络模块一起使用
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
与其他设备类型一样,XLA 张量只能与同一设备上的其他 XLA 张量一起使用。所以像
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
这样的代码会引发错误,因为 torch.nn.Linear
模块在 CPU 上。
在 XLA 设备上运行模型¶
构建新的 PyTorch 网络或转换现有网络以在 XLA 设备上运行只需要几行特定于 XLA 的代码。以下代码片段突出显示了在单个设备上运行以及使用 XLA 多处理在多个设备上运行时出现的这些行。
在单个 XLA 设备上运行¶
以下代码片段显示了在单个 XLA 设备上训练网络
import torch_xla.core.xla_model as xm
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
此代码片段突出显示了将模型切换到在 XLA 上运行是多么容易。模型定义、数据加载器、优化器和训练循环可以在任何设备上工作。唯一特定于 XLA 的代码是获取 XLA 设备并标记步骤的几行代码。在每次训练迭代结束时调用 xm.mark_step()
会导致 XLA 执行其当前图并更新模型的参数。有关 XLA 如何创建图和运行操作的更多信息,请参阅 XLA 张量深度解析。
使用多处理在多个 XLA 设备上运行¶
PyTorch/XLA 可以通过在多个 XLA 设备上运行来轻松加速训练。以下代码片段显示了如何实现
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
def _mp_fn(index):
device = xm.xla_device()
mp_device_loader = pl.MpDeviceLoader(train_loader, device)
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in mp_device_loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer)
if __name__ == '__main__':
xmp.spawn(_mp_fn, args=())
此多设备代码片段与之前的单设备代码片段有三个区别。让我们逐一介绍一下。
xmp.spawn()
创建每个都运行 XLA 设备的进程。
每个进程只能访问分配给当前进程的设备。例如,在 TPU v4-8 上,将生成 4 个进程,每个进程将拥有一个 TPU 设备。
请注意,如果您在每个进程上打印
xm.xla_device()
,您将在所有设备上看到xla:0
。这是因为每个进程只能看到一个设备。这并不意味着多进程不起作用。只有在 TPU v2 和 TPU v3 上使用 PJRT 运行时才会执行,因为将有#devices/2
个进程,每个进程将有 2 个线程(有关更多详细信息,请查看此 文档)。
MpDeviceLoader
将训练数据加载到每个设备。
MpDeviceLoader
可以包装在 PyTorch 数据加载器上。它可以将数据预加载到设备,并将数据加载与设备执行重叠以提高性能。MpDeviceLoader
还会在每batches_per_execution
(默认为 1)个批次被生成时调用xm.mark_step
。
xm.optimizer_step(optimizer)
整合设备之间的梯度并发出 XLA 设备步骤计算。
它几乎等同于
all_reduce_gradients
+optimizer.step()
+mark_step
,并返回被减少的损失。
模型定义、优化器定义和训练循环保持不变。
**注意:** 需要注意的是,当使用多进程时,用户只能从
xmp.spawn()
的目标函数(或任何在调用堆栈中以xmp.spawn()
为父级的函数)内部开始检索和访问 XLA 设备。
有关在多个 XLA 设备上使用多进程训练网络的更多信息,请参阅 完整的 multiprocessing 示例。
在 TPU Pod 上运行¶
不同加速器的多主机设置可能非常不同。本文档将讨论多主机训练中与设备无关的部分,并将以 TPU + PJRT 运行时(当前在 1.13 和 2.x 版本中可用)为例。
在开始之前,请先查看 此处 的用户指南,其中将解释一些 Google Cloud 基础知识,例如如何使用 gcloud
命令以及如何设置项目。您还可以查看 此处 的所有 Cloud TPU 操作指南。本文档将重点介绍 PyTorch/XLA 在设置方面的视角。
假设您在 train_mnist_xla.py
中有上述部分的 mnist 示例。如果是单主机多设备训练,您需要 ssh 到 TPUVM 并运行如下命令
PJRT_DEVICE=TPU python3 train_mnist_xla.py
现在,为了在 TPU v4-16(有 2 个主机,每个主机有 4 个 TPU 设备)上运行相同的模型,您需要
确保每个主机都可以访问训练脚本和训练数据。这通常是通过使用
gcloud scp
命令或gcloud ssh
命令将训练脚本复制到所有主机来完成的。在所有主机上同时运行相同的训练命令。
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
上面的 gcloud ssh
命令将 ssh 到 TPUVM Pod 中的所有主机,并同时运行相同的命令。
**注意:** 您需要在 TPUVM 虚拟机外部运行上述
gcloud
命令。
多进程训练和多主机训练的模型代码和训练脚本相同。PyTorch/XLA 和底层基础设施将确保每个设备都知道全局拓扑以及每个设备的本地和全局序号。跨设备通信将在所有设备之间进行,而不是仅在本地设备之间进行。
有关 PJRT 运行时以及如何在 Pod 上运行它的更多详细信息,请参阅此 文档。有关 PyTorch/XLA 和 TPU Pod 的更多信息,以及在 TPU Pod 上运行带有虚假数据的 resnet50 的完整指南,请参阅此 指南。
XLA 张量深度解析¶
使用 XLA 张量和设备只需要更改几行代码。但是,即使 XLA 张量的行为与 CPU 和 CUDA 张量非常相似,它们的内部结构也不同。本节介绍 XLA 张量的独特之处。
XLA 张量是惰性的¶
CPU 和 CUDA 张量会立即或急切地启动操作。另一方面,XLA 张量是惰性的。它们会将操作记录在图中,直到需要结果为止。像这样推迟执行可以使 XLA 对其进行优化。例如,可以将多个独立操作的图融合到一个优化的操作中。
惰性执行通常对调用者不可见。PyTorch/XLA 会自动构建图,将它们发送到 XLA 设备,并在 XLA 设备和 CPU 之间复制数据时进行同步。在执行优化器步骤时插入屏障会显式同步 CPU 和 XLA 设备。有关我们惰性张量设计的更多信息,您可以阅读 这篇论文。
XLA 张量和 bFloat16¶
PyTorch/XLA 在 TPU 上运行时可以使用 bfloat16 数据类型。实际上,PyTorch/XLA 在 TPU 上处理浮点类型(torch.float
和 torch.double
)的方式不同。此行为由 XLA_USE_BF16
和 XLA_DOWNCAST_BF16
环境变量控制
默认情况下,
torch.float
和torch.double
在 TPU 上都是torch.float
。如果设置了
XLA_USE_BF16
,则torch.float
和torch.double
在 TPU 上都是bfloat16
。如果设置了
XLA_DOWNCAST_BF16
,则torch.float
在 TPU 上是bfloat16
,torch.double
在 TPU 上是float32
。如果 PyTorch 张量的 data type 为
torch.bfloat16
,则它将直接映射到 TPUbfloat16
(XLABF16
原语类型)。
开发人员应该注意,*TPU 上的 XLA 张量始终会报告其 PyTorch 数据类型*,而不管它们实际使用的数据类型是什么。此转换是自动且不透明的。如果将 TPU 上的 XLA 张量移回 CPU,则会将其从实际数据类型转换为 PyTorch 数据类型。根据您的代码如何运行,由处理单元类型触发的这种转换可能很重要。
内存布局¶
XLA 张量的内部数据表示对用户是不透明的。与 CPU 和 CUDA 张量不同,它们不会公开其存储,并且始终看起来是连续的。这使得 XLA 可以调整张量的内存布局以获得更好的性能。
在 CPU 和 XLA 设备之间移动 XLA 张量¶
可以将 XLA 张量从 CPU 移动到 XLA 设备,也可以从 XLA 设备移动到 CPU。如果移动视图,则其查看的数据也会复制到另一个设备,并且视图关系不会保留。换句话说,一旦数据被复制到另一个设备,它就与其先前的设备或其上的任何张量没有关系。同样,根据您的代码如何运行,理解和适应这种转换可能很重要。
保存和加载 XLA 张量¶
在保存 XLA 张量之前,应将其移动到 CPU,如以下代码段所示
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
这使您可以将加载的张量放在任何可用的设备上,而不仅仅是初始化它们时的设备。
根据上面关于将 XLA 张量移动到 CPU 的说明,在使用视图时必须小心。建议您在加载张量并将其移动到目标设备后重新创建视图,而不是保存视图。
提供了一个实用程序 API,可以通过先将数据移动到 CPU 来保存数据
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
如果有多个设备,则上述 API 将仅保存主设备序号(0)的数据。
如果内存与模型参数的大小相比有限,则提供了一个 API 来减少主机上的内存占用
import torch_xla.utils.serialization as xser
xser.save(model.state_dict(), path)
此 API 一次将一个 XLA 张量流式传输到 CPU,从而减少了主机内存的使用量,但需要匹配的加载 API 才能恢复
import torch_xla.utils.serialization as xser
state_dict = xser.load(path)
model.load_state_dict(state_dict)
可以直接保存 XLA 张量,但不建议这样做。XLA 张量始终会加载回保存它们的设备,如果该设备不可用,则加载将失败。与所有 PyTorch 一样,PyTorch/XLA 正在积极开发中,此行为将来可能会发生变化。
编译缓存¶
XLA 编译器将跟踪的 HLO 转换为在设备上运行的可执行文件。编译可能非常耗时,如果 HLO 在执行过程中没有变化,可以将编译结果持久化到磁盘以供重复使用,从而显著减少开发迭代时间。
请注意,如果 HLO 在执行之间发生变化,则仍将进行重新编译。
这目前是一个实验性的可选 API,必须在执行任何计算之前激活。初始化是通过 initialize_cache
API 完成的
import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)
这将在指定的路径初始化一个持久化的编译缓存。readonly
参数可用于控制工作线程是否能够写入缓存,这在为 SPMD 工作负载使用共享缓存挂载时非常有用。
进一步阅读¶
更多文档可在 PyTorch/XLA 代码仓库 中找到。更多在 TPU 上运行网络的示例可在 此处 找到。
PyTorch/XLA API¶
xla_model¶
- torch_xla.core.xla_model.xla_device(n=None, devkind=None)[源代码]¶
返回 XLA 设备的给定实例。
- 参数
n (python:int, 可选) – 要返回的特定实例(序号)。如果指定,则返回特定的 XLA 设备实例。否则,将返回 devkind 的第一个设备。
devkind (字符串..., 可选) – 如果指定,则为设备类型,例如 TPU、CUDA、CPU 或自定义 PJRT 设备。已弃用。
- 返回值
具有请求实例的 torch.device。
- torch_xla.core.xla_model.get_xla_supported_devices(devkind=None, max_devices=None)[源代码]¶
返回给定类型支持的设备列表。
- 参数
devkind (字符串..., 可选) – 如果指定,则为设备类型,例如 TPU、CUDA、CPU 或自定义 PJRT 设备的名称。
max_devices (python:int, 可选) – 要返回的该类型设备的最大数量。
- 返回值
0'、'xla:1',…]
- 返回类型
设备字符串列表,例如 ['xla
- torch_xla.core.xla_model.xla_device_hw(device)[源代码]¶
返回给定设备的硬件类型。
- 参数
device (字符串 或 torch.device) – 将映射到真实设备的 xla 设备。
- 返回值
给定设备的硬件类型的字符串表示形式。
- torch_xla.core.xla_model.get_ordinal(defval=0)[源代码]¶
检索当前线程的副本序号。
序号范围从 0 到 xrt_world_size() 减 1。
- 参数
defval (python:int, 可选) – 如果没有可用的副本信息,则返回的默认值。对于运行时将被忽略。默认值:0
- 返回值
当前线程的副本序号。
- torch_xla.core.xla_model.get_local_ordinal(defval=0)[源代码]¶
检索当前线程的副本本地序号。
本地序号范围从 0 到本地设备数量减 1。
- 参数
defval (python:int, 可选) – 如果没有可用的副本信息,则返回的默认值。对于运行时将被忽略。默认值:0
- 返回值
当前线程的副本本地序号。
- torch_xla.core.xla_model.is_master_ordinal(local=True)[源代码]¶
检查当前进程是否为主序号 (0)。
- 参数
local (布尔值) – 是否应检查本地或全局主序号。如果是多主机复制,则只有一个全局主序号(主机 0,设备 0),而有 NUM_HOSTS 个本地主序号。默认值:True
- 返回值
指示当前进程是否为主序号的布尔值。
- torch_xla.core.xla_model.xrt_world_size(defval=1)[源代码]¶
检索参与复制的设备数量。
- 参数
defval (python:int, 可选) – 如果没有可用的副本信息,则返回的默认值。默认值:1
- 返回值
参与复制的设备数量。
- torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True)[源代码]¶
对输入张量执行就地归约操作。
- 参数
reduce_type (字符串) –
xm.REDUCE_SUM
、xm.REDUCE_MUL
、xm.REDUCE_AND
、xm.REDUCE_OR
、xm.REDUCE_MIN
和xm.REDUCE_MAX
之一。inputs – 要对其执行全部归约操作的单个 torch.Tensor 或 torch.Tensor 列表。
scale (python:float) – 在归约后应用的默认缩放值。默认值:1.0
groups (列表, 可选) –
一个列表列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个组,其中包含所有副本。
pin_layout (布尔值, 可选) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程的程序略有不同时出现潜在的数据损坏,但它可能会导致某些 xla 编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。
- 返回值
如果传递单个 torch.Tensor,则返回值为包含归约值(跨副本)的 torch.Tensor。如果传递列表/元组,则此函数对输入张量执行就地全部归约操作,并返回列表/元组本身。
- torch_xla.core.xla_model.all_gather(value, dim=0, groups=None, output=None, pin_layout=True)[源代码]¶
沿给定维度执行全部收集操作。
- 参数
value (torch.Tensor) – 输入张量。
dim (python:int) – 收集维度。默认值:0
groups (列表, 可选) –
一个列表列表,表示 all_gather() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个组,其中包含所有副本。
output (torch.Tensor) – 可选的输出张量。
pin_layout (布尔值, 可选) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程的程序略有不同时出现潜在的数据损坏,但它可能会导致某些 xla 编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。
- 返回值
在
dim
维度上具有来自参与副本的所有值的张量。
- torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)[源代码]¶
对输入张量执行 XLA AllToAll() 操作。
参见:https://tensorflowcn.cn/xla/operation_semantics#alltoall
- 参数
value (torch.Tensor) – 输入张量。
split_dimension (python:int) – 应该在其上进行拆分的维度。
concat_dimension (python:int) – 应该在其上进行连接的维度。
split_count (python:int) – 拆分计数。
groups (列表, 可选) –
一个列表列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个组,其中包含所有副本。
pin_layout (布尔值, 可选) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程的程序略有不同时出现潜在的数据损坏,但它可能会导致某些 xla 编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。
- 返回值
all_to_all() 操作的结果 torch.Tensor。
- torch_xla.core.xla_model.add_step_closure(closure, args=(), run_async=False)[源代码]¶
将闭包添加到要在步骤结束时运行的闭包列表中。
在模型训练过程中,很多时候需要打印/报告(打印到控制台,发布到 TensorBoard 等)信息,这需要检查中间张量的值。在模型代码的不同点检查不同张量的值需要多次执行,并且通常会导致性能问题。添加 step 闭包将确保它在 barrier 之后运行,届时所有活动张量都将已实例化为设备数据。活动张量将包括闭包参数捕获的那些张量。因此,使用 <cite>add_step_closure()</cite> 将确保即使在多个闭包排队并需要检查多个张量时,也只会执行一次。Step 闭包将按照排队顺序依次运行。请注意,即使使用此 API 优化了执行,也建议每 N 步限制一次打印/报告事件。
- 参数
closure (callable) – 要调用的函数。
args (tuple) – 要传递给闭包的参数。
run_async – 如果为 True,则异步运行闭包。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">wait_device_ops</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">devices</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#wait_device_ops"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.wait_device_ops" title="永久链接到此定义">¶</a>
等待给定设备上的所有异步操作完成。
- 参数
devices (string..., 可选) – 需要等待其异步操作完成的设备。如果为空,则将等待所有本地设备。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">optimizer_step</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">optimizer</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">barrier</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">optimizer_args</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">{}</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">groups</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pin_layout</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#optimizer_step"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.optimizer_step" title="永久链接到此定义">¶</a>
运行提供的优化器步骤并发出 XLA 设备步骤计算。
- 参数
optimizer (<code class="xref py py-class docutils literal notranslate"><span class="pre">torch.Optimizer</span></code>) – 需要调用其 <cite>step()</cite> 函数的 <cite>torch.Optimizer</cite> 实例。将使用 <cite>optimizer_args</cite> 命名参数调用 <cite>step()</cite> 函数。
barrier (bool, 可选) – 是否应在此 API 中发出 XLA 张量 barrier。如果使用 PyTorch XLA <cite>ParallelLoader</cite> 或 <cite>DataParallel</cite> 支持,则不需要这样做,因为 barrier 将由 XLA 数据加载器迭代器 <cite>next()</cite> 调用发出。默认值:False
optimizer_args (dict, 可选) – <cite>optimizer.step()</cite> 调用的命名参数字典。
groups (列表, 可选) –
一个列表列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个组,其中包含所有副本。
pin_layout (bool, 可选) – 在减少梯度时是否固定布局。有关详细信息,请参阅 <cite>xm.all_reduce</cite>。
- 返回值
<cite>optimizer.step()</cite> 调用返回的相同值。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">save</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">file_or_path</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">master_only</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">global_master</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#save"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.save" title="永久链接到此定义">¶</a>
将输入数据保存到文件中。
保存的数据在保存之前会被传输到 PyTorch CPU 设备,因此之后的 <cite>torch.load()</cite> 将加载 CPU 数据。使用视图时必须小心。建议您在加载张量并将其移动到目标设备后重新创建视图,而不是保存视图。
- 参数
data – 要保存的输入数据。Python 对象(列表、元组、集合、字典等)的任意嵌套组合。
file_or_path – 数据保存操作的目标位置。可以是文件路径或 Python 文件对象。如果 <cite>master_only</cite> 为 <code class="docutils literal notranslate"><span class="pre">False</span></code>,则路径或文件对象必须指向不同的目标位置,否则来自同一主机的所有写入都将相互覆盖。
master_only (bool, 可选) – 是否只有主设备应该保存数据。如果为 False,则 <cite>file_or_path</cite> 参数对于参与复制的每个序号应该是不同的文件或路径,否则同一主机上的所有副本都将写入同一位置。默认值:True
global_master (bool, 可选) – 当 <code class="docutils literal notranslate"><span class="pre">master_only</span></code> 为 <code class="docutils literal notranslate"><span class="pre">True</span></code> 时,此标志控制是每个主机的主设备(如果 <code class="docutils literal notranslate"><span class="pre">global_master</span></code> 为 <code class="docutils literal notranslate"><span class="pre">False</span></code>)保存内容,还是只有全局主设备(序号 0)保存内容。默认值:False
sync (bool, 可选) – 是否在保存张量后同步所有副本。如果为 True,则所有副本都必须调用 <cite>xm.save</cite>,否则主进程将挂起。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">rendezvous</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">tag</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">payload</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">b''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">replicas</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[]</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#rendezvous"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.rendezvous" title="永久链接到此定义">¶</a>
等待所有网格客户端到达指定的集合点。
注意:PJRT 不支持 XRT 网格服务器,因此这实际上是 <cite>xla_rendezvous</cite> 的别名。
- 参数
tag (string) – 要加入的集合点的名称。
payload (bytes, 可选) – 要发送到集合点的负载。
replicas (list, python:int) – 参与集合点的副本序号。空表示网格中的所有副本。默认值:[]
- 返回值
所有其他核心交换的负载,核心序号 <cite>i</cite> 的负载位于返回元组中的位置 <cite>i</cite>。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">do_on_ordinals</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">target</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ordinals</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">(0,)</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#do_on_ordinals"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.do_on_ordinals" title="永久链接到此定义">¶</a>
仅在给定的一组序号上运行函数。
- 参数
target (callable) – 要在 <cite>ordinals</cite> 上运行的函数。
data – <cite>target</cite> 函数的任何输入数据,其中包含张量。<cite>target</cite> 函数使用的所有 XLA 张量都必须在此参数中传递。函数使用的所有其他数据都可以像往常一样由 Python 解释器捕获。默认值:()
ordinals (list, python:int) – 应该运行 <cite>target</cite> 函数的序号列表/集合。默认值:(0,)
- 返回值
在运行 <cite>target</cite> 函数的序号中,返回函数返回值,否则返回 <cite>None</cite>。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">mesh_reduce</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">tag</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">data</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduce_fn</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#mesh_reduce"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.mesh_reduce" title="永久链接到此定义">¶</a>
执行图外客户端网格约简。
- 参数
tag (string) – 要加入的集合点的名称。
data – 要约简的数据。<cite>reduce_fn</cite> 可调用对象将接收一个列表,其中包含来自所有网格客户端进程(每个核心一个)的相同数据的副本。
reduce_fn (callable) – 接收 <cite>data</cite> 类对象列表并返回约简结果的函数。
- 返回值
约简后的值。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">set_rng_state</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">seed</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#set_rng_state"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.set_rng_state" title="永久链接到此定义">¶</a>
设置随机数生成器状态。
- 参数
seed (python:integer) – 要设置的状态。
device (string, 可选) – 需要设置 RNG 状态的设备。如果缺少,则将设置默认设备种子。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">get_rng_state</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#get_rng_state"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.get_rng_state" title="永久链接到此定义">¶</a>
获取当前运行的随机数生成器状态。
- 参数
device (string, 可选) – 需要检索其 RNG 状态的设备。如果缺少,则将设置默认设备种子。
- 返回值
RNG 状态,为整数。
- <span class="sig-prename descclassname"><span class="pre">torch_xla.core.xla_model.</span></span><span class="sig-name descname"><span class="pre">get_memory_info</span></span><span class="sig-paren">(<em class="sig-param"><span class="n"><span class="pre">device</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/torch_xla/core/xla_model.html#get_memory_info"><span class="viewcode-link"><span class="pre">[源代码]</span></span></a><a class="headerlink" href="#torch_xla.core.xla_model.get_memory_info" title="永久链接到此定义">¶</a>
检索设备内存信息。
- 参数
device (string) – 请求其内存信息的设备。
- 返回值
一个字典,包含 kb_free(可用内存,单位为 KB)和 kb_total(总内存,单位为 KB)键。
- torch_xla.core.xla_model.get_stablehlo(tensors=None) str [source]¶
以字符串格式获取计算图的 StableHLO。
如果 tensors 不为空,则将转储以 tensors 作为输出的图。如果 tensors 为空,则将转储整个计算图。TODO(lsy323):当 tensors 为空时,一些中间张量也将作为输出进行转储。需要进一步调查。
对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不直接。建议使用空的 tensors。
要在 StableHLO 中启用源代码行信息,请设置环境变量 XLA_HLO_DEBUG=1。
- 参数
tensors (list[torch.Tensor], 可选) – 表示 StableHLO 图的输出/根的张量。
- 返回值
字符串格式的 StableHLO 模块。
- torch_xla.core.xla_model.get_stablehlo_bytecode(tensors=None) bytes [source]¶
以字节码格式获取计算图的 StableHLO。
如果 tensors 不为空,则将转储以 tensors 作为输出的图。如果 tensors 为空,则将转储整个计算图。TODO(lsy323):当 tensors 为空时,一些中间张量也将作为输出进行转储。需要进一步调查。
对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不直接。建议使用空的 tensors。
- 参数
tensors (list[torch.Tensor], 可选) – 表示 StableHLO 图的输出/根的张量。
- 返回值
字节码格式的 StableHLO 模块。
- torch_xla.core.functions.all_reduce(reduce_type, value, scale=1.0, groups=None)[source]¶
对输入张量执行就地 reduce 操作。
这与 xm.all_reduce() 相同,但支持自动求导。
- 参数
reduce_type (string) –
REDUCE_SUM
、REDUCE_MUL
、REDUCE_AND
、REDUCE_OR
、REDUCE_MIN
和REDUCE_MAX
中的一个。value (torch.Tensor) – 要对其执行 all reduce 操作的张量。
scale (python:float) – 在归约后应用的默认缩放值。默认值:1.0
groups (列表, 可选) –
一个列表列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个组,其中包含所有副本。
- 返回值
跨选定副本的 reduce 值。
- torch_xla.core.functions.all_gather(value, dim=0)[source]¶
沿给定维度执行全部收集操作。
这与 xm.all_gather() 相同,但支持自动求导。
- 参数
value (torch.Tensor) – 输入张量。
dim (python:int) – 收集维度。默认值:0
- 返回值
在
dim
维度上具有来自参与副本的所有值的张量。
- torch_xla.core.functions.nms(boxes, scores, score_threshold, iou_threshold, output_size)[source]¶
执行非极大值抑制 (Non Maximal Suppression) 操作。
- 参数
boxes (torch.Tensor) – 形状为 [N, 4] 的 torch.Tensor,以 (y0, x0, y1, x1) 形式列出边界框坐标。
scores (torch.Tensor) – 形状为 [N] 的 torch.Tensor,列出每个边界框的分数。
score_threshold (torch.Tensor) – 边界框被视为有效的最低分数。
iou_threshold (torch.Tensor) – 触发重叠逻辑的最小 IOU(交并比)分数。
output_size (python:int) – 返回索引的最大数量(必须小于或等于 N)。
- 返回值
一个 torch.Tensor 元组,第一个元素是选定的边界框索引,第二个元素是有效边界框的数量。
distributed¶
- class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4, host_to_device_transfer_threads=1, input_sharding=None)[source]¶
使用后台数据上传包装现有的 PyTorch DataLoader。
- 参数
loader (
torch.utils.data.DataLoader
) – 要包装的 PyTorch DataLoader。devices (torch.device…) – 数据要发送到的设备列表。loader 返回的第 i 个样本将被发送到 devices[i % len(devices)]。
batchdim (python:int, 可选) – 保存批大小的维度。默认值:0
loader_prefetch_size (python:int, 可选) – 从 loader 读取样本的线程使用的队列的最大容量,这些样本将由将数据上传到设备的工作线程进行处理。默认值:8
device_prefetch_size (python:int, 可选) – 每个设备队列的最大大小,工作线程将已发送到设备的张量存放到这些队列中。默认值:4
host_to_device_transfer_threads (python:int, 可选) – 并行工作以将数据从加载程序队列传输到设备队列的线程数。默认值:1
input_sharding (ShardingSpec, 可选) – 加载后应用于兼容输入张量的切片规范。默认值:None
- torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]¶
启用基于多处理的复制。
- 参数
fn (可调用对象) – 为参与复制的每个设备调用的函数。该函数将使用第一个参数作为复制过程中进程的全局索引,后跟传递给 args 的参数来调用。
args (tuple) – fn 的参数。默认值:空元组
nprocs (python:int) – 用于复制的进程/设备的数量。目前,如果指定,可以是 1 或最大设备数量。
join (bool) – 调用是否应阻塞等待已生成进程的完成。默认值:True
daemon (bool) – 生成的进程是否应设置 daemon 标志(请参阅 Python 多处理 API)。默认值:False
start_method (string) – Python multiprocessing 进程创建方法。默认值:spawn
- 返回值
torch.multiprocessing.spawn API 返回的相同对象。如果 nprocs 为 1,则将直接调用 fn 函数,并且 API 将返回 None。
- class torch_xla.distributed.xla_multiprocessing.MpModelWrapper(model)[source]¶
包装模型以在使用 fork 方法时最大程度地减少主机内存使用量。
此类应与 spawn(…, start_method=’fork’) API 一起使用,以最大程度地减少主机内存的使用。无需在每个多处理进程上创建模型(这会复制模型的初始主机内存),而是在全局范围内创建一次模型,然后在 spawn() 目标函数内部将其移入每个设备。示例
WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork()) def _mp_fn(index, ...): device = xm.xla_device() model = WRAPPED_MODEL.to(device) ... xmp.spawn(_mp_fn, ..., start_method='fork')
这种方法有两个优点。首先,它只使用内存页的一个副本托管原始模型权重;其次,它通过在过程中降低系统内存的负载,将包装模型的移动序列化到每个设备中。
- class torch_xla.distributed.xla_multiprocessing.MpSerialExecutor[source]¶
在多核进程之间以序列化方式运行函数的实用程序。
示例
# At global scope. SERIAL_EXEC = xmp.MpSerialExecutor() def load_dataset(path): return maybe_download_and_load(path) def _mp_fn(index, ...): # Avoid all cores downloading the same data with the serial executor. dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data')) ... xmp.spawn(_mp_fn, ...)
utils¶
- class torch_xla.utils.utils.SampleGenerator(data, sample_count)[source]¶
迭代器,返回给定输入数据的多个样本。
可以代替 PyTorch DataLoader 来生成合成数据。
- 参数
data – 应在每个迭代器步骤返回的数据。
sample_count – 要返回的 data 样本的最大数量。
- torch_xla.utils.serialization.save(data, path, master_only=True, global_master=False)[source]¶
将输入数据保存到文件中。
保存的数据在保存之前会被传输到 PyTorch CPU 设备,因此之后的 <cite>torch.load()</cite> 将加载 CPU 数据。使用视图时必须小心。建议您在加载张量并将其移动到目标设备后重新创建视图,而不是保存视图。
- 参数
data – 要保存的输入数据。Python 对象(列表、元组、集合、字典等)的任意嵌套组合。
path – 数据保存操作的目标文件。如果 master_only 为
False
,则路径必须指向不同的目标,否则来自同一主机的所有写入都将相互覆盖。master_only (bool, 可选) – 是否只有主设备应该保存数据。如果为 False,则 path 参数对于参与复制的每个序号都应该是不同的路径,否则同一主机上的所有副本都将写入同一位置。默认值:True
global_master (bool, 可选) – 当 <code class="docutils literal notranslate"><span class="pre">master_only</span></code> 为 <code class="docutils literal notranslate"><span class="pre">True</span></code> 时,此标志控制是每个主机的主设备(如果 <code class="docutils literal notranslate"><span class="pre">global_master</span></code> 为 <code class="docutils literal notranslate"><span class="pre">False</span></code>)保存内容,还是只有全局主设备(序号 0)保存内容。默认值:False
测试¶
PyTorch/XLA 入门指南¶
本文档概述了 PyTorch XLA,并通过几个示例说明了如何转换 PyTorch 代码以在 XLA 设备(例如 TPU)上运行。这不是一个完整的解决方案,可能需要根据具体代码进行其他更改。但是,本文档应该作为转换过程的起点。
对一些 XLA 细节的基本高级理解¶
- 本节简要概述了 PyTorch XLA 的基本细节,
这应该有助于读者更好地理解代码所需的修改和优化。它是对此处所述 API 指南的补充。
与逐行执行代码并且在获取 PyTorch 张量 的值之前不会阻塞执行的常规 PyTorch 不同,PyTorch XLA 的工作方式有所不同。它遍历 Python 代码并记录对 (PyTorch)XLA 张量 在中间表示 (IR) 图中的操作,直到遇到障碍(下文讨论)。生成 IR 图的过程称为跟踪(LazyTensor 跟踪或代码跟踪)。然后,PyTorch XLA 将 IR 图转换为称为 HLO(高级操作码)的低级机器可读格式。HLO 是特定于 XLA 编译器的计算表示形式,允许它为其运行的硬件生成高效代码。HLO 被馈送到 XLA 编译器以进行编译和优化。然后,编译结果由 PyTorch XLA 缓存,以便在以后需要时重复使用。图的编译在主机(CPU)上完成,主机是运行 Python 代码的机器。如果有多个 XLA 设备,则主机将分别为每个设备编译代码,除非使用 SPMD(单程序多数据)。例如,v4-8 有一台主机和四台设备。在这种情况下,主机将分别为四台设备编译代码。对于 pod 切片,当有多个主机时,每个主机都对其连接的 XLA 设备进行编译。如果使用 SPMD,则代码只在每个主机上针对所有设备编译一次(对于给定的形状和计算)。
有关更多详细信息和示例,请参阅LazyTensor 指南。
仅当需要张量的值时,才会执行 IR 图中的操作。这称为张量的评估或物化。有时,这也称为延迟评估,它可以显著提高性能。
Pytorch XLA 中的*同步*操作(如打印、日志记录、检查点或回调)会阻塞跟踪并导致执行速度变慢。如果操作需要 XLA 张量的特定值(例如 print(xla_tensor_z)
),则跟踪会阻塞,直到主机可以使用该张量的值。请注意,只会执行负责计算该张量值的部分图。这些操作不会切割 IR 图,但它们会通过 TransferFromDevice
触发主机与设备之间的通信,从而导致性能下降。
*障碍*是一种特殊指令,它告诉 XLA 执行 IR 图并实例化张量。这意味着将评估 PyTorch XLA 张量,并且主机将可以使用结果。Pytorch XLA 中面向用户的障碍是 xm.mark_step(),它会破坏 IR 图并导致在 XLA 设备上执行代码。xm.mark_step
的一个关键特性是,与同步操作不同,它不会在设备执行图时阻塞进一步的跟踪。但是,它确实会阻止访问正在物化的张量的值。
LazyTensor 指南中的示例说明了在添加两个张量的简单情况下会发生什么。现在,假设我们有一个 for 循环,它添加 XLA 张量并在稍后使用该值
for x, y in tensors_on_device:
z += x + y
如果没有障碍,Python 跟踪将生成一个单一图,该图将包装 len(tensors_on_device)
次张量的加法。这是因为跟踪不会捕获 for
循环,因此循环的每次迭代都会创建一个与计算 z += x+y
相对应的新子图,并将其添加到图中。下面是 len(tensors_on_device)=3
时的示例。

但是,在循环末尾引入障碍将导致生成一个更小的图,该图将在 for
循环内的第一次传递期间编译一次,并在接下来的 len(tensors_on_device)-1
次迭代中重复使用。障碍将向跟踪发出信号,表明到目前为止跟踪的图可以提交执行,如果之前见过该图,则将重复使用缓存的已编译程序。
for x, y in tensors_on_device:
z += x + y
xm.mark_step()
在这种情况下,将有一个使用 len(tensors_on_device)=3
次的小图。

需要强调的是,在 PyTorch XLA 中,如果 for 循环末尾有障碍,则会跟踪 for 循环内的 Python 代码,并为每次迭代构造一个新图。这可能是一个严重的性能瓶颈。
当对相同形状的张量进行相同计算时,可以重复使用 XLA 图。如果输入或中间张量的形状发生变化,则 XLA 编译器将使用新的张量形状重新编译一个新图。这意味着,如果您有动态形状,或者您的代码不重复使用张量图,则在 XLA 上运行模型将不适合该用例。将输入填充到固定形状中可以作为一种避免动态形状的选择。否则,编译器将花费大量时间来优化和融合不会再次使用的操作。
图大小和编译时间之间的权衡也很重要。如果有一个大型 IR 图,则 XLA 编译器可能会花费大量时间来优化和融合操作。这可能导致编译时间非常长。但是,由于在编译期间执行了优化,因此以后的执行可能会快得多。
有时,使用 xm.mark_step()
打破 IR 图是值得的。如上所述,这将产生一个可以在以后重复使用的较小图。但是,使图变小会减少 XLA 编译器可以执行的优化。
另一个需要考虑的重点是 MPDeviceLoader。一旦您的代码在 XLA 设备上运行,请考虑使用 XLA MPDeviceLoader
包装 torch 数据加载器,该加载器会将数据预加载到设备以提高性能,并包含 xm.mark_step()
。后者会自动中断对数据批次的迭代并将它们发送以供执行。请注意,如果您没有使用 MPDeviceLoader,则可能需要在 optimizer_step()
中设置 barrier=True
以在运行训练作业时启用 xm.mark_step()
,或者显式添加 xm.mark_step()
。
TPU 设置¶
使用基本映像创建 TPU 以使用夜间版本,或通过指定 RUNTIME_VERSION
从稳定版本创建 TPU。
export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v4-8 # v4-16, v4-32, …
export RUNTIME_VERSION=tpu-vm-v4-pt-2.0 # or tpu-vm-v4-base
export TPU_NAME=your_tpu_name
gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version=${RUNTIME_VERSION} \
--subnetwork=tpusubnet
如果您有一个主机虚拟机(例如 v4-8),则可以 ssh 到您的虚拟机并直接从虚拟机运行以下命令。否则,对于 TPU pod,您可以使用 --worker=all --command=""
,类似于
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=us-central2-b \
--worker=all \
--command="pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl"
接下来,如果您使用的是基本映像,请安装夜间软件包和所需的库
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
sudo apt-get install libopenblas-dev -y
sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific
将代码转换为 PyTorch XLA¶
修改代码的一般准则
将
cuda
替换为xm.xla_device()
删除进度条,打印将访问 XLA 张量值
减少将访问 XLA 张量值的日志记录和回调
使用 MPDeviceLoader 包装数据加载器
进行性能分析以进一步优化代码
请记住:每个案例都是唯一的,因此您可能需要针对每个案例采取不同的措施。
示例 1:在单个 TPU 设备上使用 PyTorch Lightning 进行稳定扩散推理¶
作为第一个示例,请考虑 PyTorch Lightning 中稳定扩散模型的 推理代码,该代码可以从命令行运行,如下所示
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse"
供您参考,可以在 此处 找到下面描述的修改的差异。让我们逐步介绍它们。如上文的一般准则中所述,请从与 cuda
设备相关的更改开始。此推理代码编写为在 GPU 上运行,并且可以在多个位置找到 cuda
。首先删除 此行 中的 model.cuda()
以及 此处 中的 precision_scope
,开始进行更改。此外,将 此行 中的 cuda
设备替换为 xla
设备,类似于以下代码
接下来,此特定模型配置正在使用 FrozenCLIPEmbedder
,因此我们还将修改 此行。为了简单起见,我们将在本教程中直接定义 device
,但您也可以将 device
值传递给函数。
import torch_xla.core.xla_model as xm
self.device = xm.xla_device()
代码中另一个包含 cuda 特定代码的位置是 DDIM 调度程序。在文件顶部添加 import torch_xla.core.xla_model as xm
,然后替换 这些 行
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
替换为
device = xm.xla_device()
attr = attr.to(torch.device(device))
接下来,您可以通过删除打印语句、禁用进度条以及减少或删除回调和日志记录来减少设备(TPU)和主机(CPU)之间的通信。这些操作需要设备停止执行,回退到 CPU,执行日志记录/回调,然后返回到设备。这可能是一个严重的性能瓶颈,尤其是在大型模型上。
进行这些更改后,代码将在 TPU 上运行。但是,性能会非常慢。这是因为 XLA 编译器试图构建一个包含推理步骤数(在本例中为 50 个)的单个(巨大)图,因为 for 循环中没有障碍。编译器难以优化图,这会导致性能显着下降。如上所述,使用障碍 (xm.mark_step()) 打破 for 循环将产生一个更小的图,编译器更容易优化该图。这还将允许编译器重用上一步中的图,从而提高性能。
现在,代码 已准备好 在合理的时间内在 TPU 上运行。可以通过 捕获配置文件 并进行进一步调查来进行更多优化和分析。但是,此处不作介绍。
注意:如果您在 v4-8 TPU 上运行,则有 4 个可用的 XLA (TPU) 设备。如上所述运行代码只会使用一个 XLA 设备。为了在所有 4 个设备上运行,您需要使用 xmp.spawn()
函数在所有设备上生成代码。我们将在下一个示例中讨论 xmp.spawn
。
示例 2:HF 稳定扩散推理¶
现在,考虑在 HuggingFace diffusers 库中使用 稳定扩散推理 来处理模型的 SD-XL 和 2.1 版本。供您参考,可以在此 存储库 中找到下面描述的更改。您可以克隆存储库并在 TPU 虚拟机上使用以下命令运行推理
(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git
(vm)$ cd diffusers/examples/text_to_image/
(vm)$ python3 inference_tpu_single_device.py
由于没有 bf16 版本的 SD-XL 模型可用,因此您可以使用 XLA_USE_BF16=1
标志将所有值转换为 bf16 并加快训练速度。
(vm)$ XLA_USE_BF16=1 python3 inference_tpu_single_device.py # uses sd-xl version
或
(vm)$ python3 inference_tpu_multidevice.py # uses 2.1 version
(模型的 2.1 版本中已包含 torch.bfloat16
)。
警告:请注意 此处 突出显示的注意事项。
在单个 TPU 设备上运行¶
本节介绍需要对 文本到图像推理示例 代码进行哪些更改才能在 TPU 上运行它。
原始代码使用 Lora 进行推理,但本教程不会使用它。相反,我们在初始化管道时会将 model_id
参数设置为 stabilityai/stable-diffusion-xl-base-0.9
。我们还将使用默认调度程序 (DPMSolverMultistepScheduler)。但是,也可以对其他调度程序进行类似的更改。
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install . # pip install -e .
cd examples/text_to_image/
pip install -r requirements.txt
pip install invisible_watermark transformers accelerate safetensors
(如果找不到 accelerate
,请注销,然后重新登录。)
登录 HF 并同意模型卡上的 sd-xl 0.9 许可证。接下来,转到 帐户→设置→访问令牌 并生成一个新令牌。复制令牌并在您的虚拟机上使用该特定令牌值运行以下命令
(vm)$ huggingface-cli login --token _your_copied_token__
HuggingFace 自述文件提供了编写为在 GPU 上运行的 PyTorch 代码。要在 TPU 上运行它,第一步是将 CUDA 设备更改为 XLA 设备。这可以通过将 pipe.to("cuda")
行替换为以下行来完成
import torch_xla.core.xla_model as xm
device = xm.xla_device()
pipe.to(device)
此外,请务必注意,第一次使用 XLA 运行推理时,编译将花费很长时间。例如,HuggingFace 的稳定扩散 XL 模型推理的编译时间可能需要大约一个小时,而实际推理可能只需要 5 秒,具体取决于批处理大小。同样,GPT-2 模型可能需要大约 10-15 分钟来编译,之后训练周期时间会变得快得多。这是因为 XLA 会构建将要执行的计算图,然后针对其运行的特定硬件优化该图。但是,一旦图编译完成,就可以将其重用于后续推理,这将快得多。因此,如果您只运行一次推理,则可能无法从使用 XLA 中受益。但是,如果您多次运行推理,或者您在提示列表上运行推理,您将在前几次推理后开始看到 XLA 的优势。例如,如果您在包含 10 个提示的列表上运行推理,则第一次推理(可能是两次1)可能需要很长时间才能编译,但剩余的推理步骤会快得多。这是因为 XLA 将重用它为第一次推理编译的图。
如果您尝试在不进行任何其他更改的情况下运行代码,您会注意到编译时间非常长(>6 小时)。这是因为 XLA 编译器尝试为所有调度程序步骤构建一个图,类似于我们在上一个示例中讨论的内容。为了使代码运行得更快,我们需要使用 xm.mark_step()
将图分成更小的部分,并在后续步骤中重用它们。这发生在 函数 pipe.__call__
内部的 这些行 中。禁用进度条、删除回调并在 for 循环末尾添加 xm.mark_step()
可以显着加快代码速度。此 提交 中提供了更改。
此外,默认使用 DPMSolverMultistepScheduler 调度器的 self.scheduler.step()
函数存在一些问题,这些问题在PyTorch XLA 注意事项中有所描述。 此函数中的 .nonzero()
和 .item()
调用会向 CPU 发送张量评估请求,从而触发设备与主机之间的通信。 这并不理想,因为它可能会降低代码的运行速度。 在这种特殊情况下,我们可以通过将索引直接传递给函数来避免这些调用。 这将阻止函数向 CPU 发送请求,并将提高代码的性能。 此提交中提供了更改。 现在,代码可以在 TPU 上运行了。
分析和性能分析¶
为了进一步调查模型的性能,我们可以使用分析指南对其进行分析。 根据经验,分析脚本应使用适合内存的最大批处理大小来运行,以实现最佳内存使用。 它还有助于将代码的跟踪与设备执行重叠,从而实现更优化的设备使用。 分析的持续时间应足够长,以至少捕获一个步骤。 模型在 TPU 上的良好性能意味着设备与主机之间的通信最少,并且设备在没有空闲时间的情况下不断运行进程。
按照指南中的说明启动 inference_tpu_*.py
文件中的服务器并运行 capture_profile.py
脚本将为我们提供有关设备上运行的进程的信息。 当前,仅分析一个 XLA 设备。 为了更好地理解 TPU 空闲时间(配置文件中的间隙),应将分析跟踪(xp.Trace()
)添加到代码中。 xp.Trace()
测量在主机上跟踪用跟踪包装的 Python 代码所需的时间。 对于此示例,在管道和U-net 模型中添加了 xp.Trace()
跟踪,以测量在主机 (CPU) 上运行代码特定部分所需的时间。
如果配置文件中的间隙是由于主机上发生的 Python 代码跟踪造成的,那么这可能是一个瓶颈,并且没有可以直接进行的进一步优化。 否则,应进一步分析代码以了解注意事项并进一步提高性能。 请注意,您不能在调用 xm.mark_step()
的代码部分使用 xp.Trace()
包装。
为了说明这一点,我们可以查看已按照分析指南上传到 TensorBoard 的已捕获配置文件。
从 Stable Diffusion 模型版本 2.1 开始
如果我们在不插入任何跟踪的情况下捕获配置文件,我们将看到以下内容

v4-8 上的单个 TPU 设备(有两个核心)似乎很忙。 除了中间的一个小间隙外,它们的用途中没有明显的间隙。 如果我们向上滚动以尝试查找哪个进程占用了主机,我们将找不到任何信息。 因此,我们将像对 2.1 版本所做的那样,将 xp.traces
添加到管道文件以及 U-net 函数中。 后者对于此特定用例可能没有用,但它确实演示了如何在不同位置添加跟踪以及它们的信息如何在 TensorBoard 中显示。
如果我们添加跟踪并使用可以放入设备的最大批处理大小(在本例中为 32)重新捕获配置文件,我们将看到设备中的间隙是由主机上运行的 Python 进程引起的。


我们可以使用适当的工具放大时间线,看看在那段时间内哪个进程正在运行。 这就是 Python 代码跟踪发生在主机上的时间,此时我们无法进一步改进跟踪。
现在,让我们检查模型的 XL 版本并执行相同的操作。 我们将像对 2.1 版本所做的那样,将跟踪添加到管道文件中,并捕获配置文件。

这一次,除了由 pipe_watermark
跟踪引起的大间隙之外,此循环内的推理步骤之间还有许多小间隙。
首先仔细查看由 pipe_watermark
引起的大间隙。 间隙之前是 TransferFromDevice
,这表明主机上正在发生某些事情,正在等待计算完成才能继续。 查看水印代码,我们可以看到张量被传输到 cpu 并转换为 numpy 数组,以便稍后使用 cv2
和 pywt
库进行处理。 由于这部分不容易优化,因此我们将保持原样。
现在,如果我们放大循环,我们可以看到循环内的图被分解成更小的部分,因为发生了 TransferFromDevice
操作。

如果我们研究 U-Net 函数和调度器,我们可以看到 U-Net 代码不包含 PyTorch/XLA 的任何优化目标。 但是,在 scheduler.step 中有 .item()
和 .nonzero()
调用。 我们可以重写函数以避免这些调用。 如果我们解决这个问题并重新运行配置文件,我们将看不到太大差异。 但是,由于我们减少了引入较小图形的设备与主机之间的通信,因此我们允许编译器更好地优化代码。 函数 scale_model_input 也有类似的问题,我们可以通过对 step
函数进行上述更改来解决这些问题。 总的来说,由于许多间隙是由 Python 级代码跟踪和图形构建引起的,因此在当前版本的 PyTorch XLA 中无法优化这些间隙,但我们可能会在将来在 PyTorch XLA 中启用 Dynamo 时看到改进。
在多个 TPU 设备上运行¶
要使用多个 TPU 设备,可以使用 xmp.spawn
函数将您在单个设备上运行的函数衍生到多个设备。 xmp.spawn
函数将在多个 TPU 设备上启动进程,并在需要时同步它们。 这可以通过将 index
参数传递给在单个设备上运行的函数来完成。 例如,
import torch_xla.distributed.xla_multiprocessing as xmp
def my_function(index):
# function that runs on a single device
xmp.spawn(my_function, args=(0,), nprocs=4)
在此示例中,my_function
函数将在 v4-8 上的 4 个 TPU 设备上衍生,每个设备都被分配了一个从 0 到 3 的索引。
此文件说明了如何使用 xmp.spawn 在多个 TPU 设备上运行稳定扩散 2.1 版本。 对于此版本,对管道文件进行了类似于上述更改的操作。
在 Pod 上运行¶
一旦您拥有在单个主机设备上运行的代码,就无需进一步更改。 您可以按照以下说明创建 TPU pod。 然后使用以下命令运行您的脚本
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--worker=all \
--command="python3 your_script.py"
- 1
0 和 1 是 XLA 中的幻数,在 HLO 中被视为常量。 因此,如果代码中存在可以生成这些值的随机数生成器,则代码将分别为每个值编译。 可以使用
XLA_NO_SPECIAL_SCALARS=1
环境变量禁用此功能。
故障排除¶
请注意,本节中的信息可能会在未来版本的 *PyTorch/XLA* 软件中删除,因为其中许多信息是特定于给定内部实现的,而这些实现可能会发生变化。
健全性检查¶
在执行任何深入调试之前,我们想对安装的 PyTorch/XLA 进行健全性检查。
检查 PyTorch/XLA 版本¶
PyTorch 和 PyTorch/XLA 版本应匹配。 有关可用版本的更多详细信息,请查看我们的README。
vm:~$ python
>>> import torch
>>> import torch_xla
>>> print(torch.__version__)
2.1.0+cu121
>>> print(torch_xla.__version__)
2.1.0
执行简单计算¶
vm:~$ export PJRT_DEVICE=TPU
vm:~$ python3
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> t1 = torch.tensor(100, device=xm.xla_device())
>>> t2 = torch.tensor(200, device=xm.xla_device())
>>> print(t1 + t2)
tensor(300, device='xla:0')
使用虚假数据运行 Resnet¶
对于每晚版本
vm:~$ git clone https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data
对于发布版本 x.y
,您要使用分支 rx.y
。 例如,如果您安装了 2.1 版本,则应该执行
vm:~$ git clone --branch r2.1 https://github.com/pytorch/xla.git
vm:~$ python xla/test/test_train_mp_imagenet.py --fake_data
如果您能让 resnet 运行,我们可以得出结论,torch_xla 已正确安装。
性能调试¶
为了诊断性能问题,我们可以使用 *PyTorch/XLA* 提供的执行指标和计数器。当模型速度慢时,首先要检查的是生成指标报告。
指标报告对于诊断问题非常有帮助。 如果您有指标报告,请尝试将其包含在发送给我们的错误报告中。
PyTorch/XLA 调试工具¶
您可以通过设置 PT_XLA_DEBUG=1
来启用 PyTorch/XLA 调试工具,该工具提供了一些有用的调试功能。
PyTorch/XLA + Dynamo 调试工具¶
您可以通过设置 XLA_DYNAMO_DEBUG=1
来启用 PyTorch/XLA + Dynamo 调试工具。
执行自动指标分析¶
调试工具将分析指标报告并提供摘要。 一些示例输出将是
pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 11 counts during 11 steps
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps
pt-xla-profiler: TransferFromDeviceTime too frequent: 12 counts during 12 steps
编译和执行分析¶
调试工具将分析模型的每次编译和执行。以下是一些示例输出
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: user mark_step
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 537d4b0264b029688281412214d252e9
Compilation Analysis: Number of Graph Inputs: 588
Compilation Analysis: Number of Graph Outputs: 320
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:840)
Compilation Analysis: broadcast_master_param (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:1230)
Compilation Analysis: train_imagenet (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:261)
Compilation Analysis: _mp_fn (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:365)
Compilation Analysis: __call__ (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:176)
Compilation Analysis: _thread_fn (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:70)
Compilation Analysis: run (/usr/local/lib/python3.8/concurrent/futures/thread.py:57)
Compilation Analysis: _worker (/usr/local/lib/python3.8/concurrent/futures/thread.py:80)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: user mark_step
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 537d4b0264b029688281412214d252e9
Execution Analysis: Number of Graph Inputs: 588
Execution Analysis: Number of Graph Outputs: 320
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:840)
Execution Analysis: broadcast_master_param (/workspaces/dk2/pytorch/xla/torch_xla/core/xla_model.py:1230)
Execution Analysis: train_imagenet (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:261)
Execution Analysis: _mp_fn (/workspaces/dk2/pytorch/xla/test/test_train_mp_imagenet.py:365)
Execution Analysis: __call__ (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:176)
Execution Analysis: _thread_fn (/workspaces/dk2/pytorch/xla/torch_xla/_internal/pjrt.py:70)
Execution Analysis: run (/usr/local/lib/python3.8/concurrent/futures/thread.py:57)
Execution Analysis: _worker (/usr/local/lib/python3.8/concurrent/futures/thread.py:80)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
编译/执行的一些常见原因是
用户手动调用
mark_step
。并行加载器 为每个 x 批次(可配置)调用
mark_step
。退出 分析器 StepTrace 区域。
Dynamo 决定编译/执行图。
用户尝试在
mark_step
之前访问张量的值(通常是由于日志记录)。
由 1-4 导致的执行是预期的,我们希望通过减少访问张量值的频率或在访问之前手动添加 mark_step
来避免 5。
用户应该预料到在前几个步骤中会看到 编译原因
+ 执行原因
对。在模型稳定之后,用户应该只会看到 执行原因
。为了有效地使用 PyTorch/XLA,我们期望为每个步骤运行相同的模型代码,并且每个图只编译一次。如果您不断看到 编译原因
,则应尝试按照本节的内容转储 IR/HLO,并比较每个步骤的图,以了解差异的来源。
下一节将解释如何获取和理解更详细的指标报告。
获取指标报告¶
在程序中加入以下代码行以生成报告
import torch_xla.debug.metrics as met
# For short report that only contains a few key metrics.
print(met.short_metrics_report())
# For full report that includes all metrics.
print(met.metrics_report())
理解指标报告¶
报告包括以下内容
我们发出 *XLA* 编译的次数以及花费的时间。
我们执行的次数以及花费的时间
我们创建/销毁的设备数据句柄的数量等。
此信息以样本百分位数的形式报告。例如
Metric: CompileTime
TotalSamples: 202
Counter: 06m09s401ms746.001us
ValueRate: 778ms572.062us / second
Rate: 0.425201 / second
Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
我们还提供计数器,它们是跟踪内部软件状态的命名整型变量。例如
Counter: CachedSyncTensors
Value: 395
在本报告中,任何以 aten::
开头的计数器都表示 XLA 设备和 CPU 之间的上下文切换,这可能是模型代码中潜在的性能优化区域。
计数器有助于了解哪些操作被路由回 *PyTorch* 的 CPU 引擎。它们使用其 C++ 命名空间进行了完全限定
Counter: aten::nonzero
Value: 33
如果您看到 aten::
操作,而不是 nonzero
和 _local_scalar_dense
,这通常意味着 PyTorch/XLA 中缺少降低。请随时在 GitHub issues 上为此提交功能请求。
已知的性能注意事项¶
PyTorch/XLA 在语义上表现得像常规 PyTorch,并且 XLA 张量与 CPU 和 GPU 张量共享完整的张量接口。但是,XLA/硬件的限制和惰性求值模型表明,某些模式可能会导致性能不佳。
如果您的模型性能不佳,请记住以下注意事项
如果重新编译次数过多,XLA/TPU 的性能会下降。
XLA 编译成本很高。每次遇到新形状时,PyTorch/XLA 都会自动重新编译图。通常,模型应在几个步骤内稳定下来,并且您可以在训练的其余部分看到巨大的速度提升。
为了避免重新编译,不仅形状必须是常量,而且所有主机中 XLA 设备上的计算也必须是常量。
可能的原因:
直接或间接使用
nonzero
会引入动态形状;例如,掩码索引base[index]
,其中index
是掩码张量。步骤之间迭代次数不同的循环可能会导致不同的执行图,因此需要重新编译。
解决方案:
张量形状在迭代之间应保持相同,或者应使用少量形状变化。
尽可能将张量填充到固定大小。
某些操作没有到 XLA 的原生转换。
对于这些操作,PyTorch/XLA 会自动传输到 CPU 内存,在 CPU 上进行评估,然后将结果传输回 XLA 设备。在训练步骤中执行太多此类操作可能会导致速度显着下降。
可能的原因:
item()
操作显式要求评估结果。除非必要,否则不要使用它。
解决方案:
对于大多数操作,我们可以将它们降低到 XLA 来解决它。查看 指标报告部分以找出缺少的操作,并在 GitHub 上提交功能请求。
即使 PyTorch 张量已知为标量,也要避免使用
tensor.item()
。将其保留为张量并对其使用张量操作。在适用时使用
torch.where
来替换控制流。例如,在 clip_grad*norm* 中使用item()
的控制流是有问题的,并且会影响性能,因此我们通过调用torch.where
对 修补 了clip_grad_norm_
,这使我们的性能得到了显着提高。 .. 代码块:: python… else
device = parameters[0].device total_norm = torch.zeros([], device=device if parameters else None) for p in parameters
param_norm = p.grad.data.norm(norm_type) ** norm_type total_norm.add_(param_norm)
total_norm = (total_norm ** (1. / norm_type))
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6) for p in parameters
p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
``torch_xla.distributed.data_parallel`` 中的迭代器可能会丢弃输入迭代器中的最后几个批次。
这是为了确保我们在所有 XLA 设备上完成相同数量的工作。
解决方案:
当数据集很小,并且步骤太少时,这可能会导致无操作时期。因此,在这些情况下最好使用小批量。
XLA 张量怪癖¶
**XLA 张量内部是不透明的。** XLA 张量始终表现为连续的,并且没有存储。网络不应尝试检查 XLA 张量的步幅。
**在保存 XLA 张量之前,应将其移动到 CPU。** 直接保存 XLA 张量会导致它们从保存它们的设备加载回来。如果在加载时设备不可用,则加载将失败。在保存 XLA 张量之前将其移动到 CPU 允许您决定将加载的张量放在哪个设备上。如果您想在没有 XLA 设备的机器上加载张量,则必须这样做。但是,在将 XLA 张量移动到 CPU 之前,应小心谨慎,因为跨设备类型移动张量不会保留视图关系。相反,应在加载张量后根据需要重建视图。
**使用 Python 的 copy.copy 复制 XLA 张量将返回深拷贝,而不是浅拷贝。** 使用 XLA 张量的视图来获取其浅拷贝。
**处理共享权重。** 模块可以通过将一个模块的参数设置为另一个模块来共享权重。模块权重的这种“绑定”应该在模块移动到 XLA 设备之后完成。否则,将在 XLA 设备上创建共享张量的两个独立副本。
更多调试工具¶
我们不希望用户使用本节中的工具来调试他们的模型。但是,当您提交错误报告时,我们可能会要求您提供这些信息,因为它们提供了指标报告中没有的额外信息。
print(torch_xla._XLAC._get_xla_tensors_text([res]))
其中res
是结果张量,打印出 IR。print(torch_xla._XLAC._get_xla_tensors_hlo([res]))
其中res
是结果张量,打印出生成的 XLA HLO。
请注意,必须在 mark_step()
之前调用这些函数,否则张量将已被物化。
环境变量¶
还有一些环境变量可以控制 *PyTorch/XLA* 软件栈的行为。
设置此类变量会导致不同程度的性能下降,因此应仅在调试时启用它们。
XLA_IR_DEBUG
:启用在创建 IR 节点时捕获 *Python* 堆栈跟踪,从而可以了解哪个 *PyTorch* 操作负责生成 IR。XLA_HLO_DEBUG
:启用在启用 _XLA_IR*DEBUG* 时捕获的 *Python* 堆栈帧,以将其传播到 *XLA* *HLO* 元数据。XLA_SAVE_TENSORS_FILE
:用于在执行期间转储 IR 图的文件的路径。请注意,如果选项保持启用状态并且 *PyTorch* 程序长时间运行,则该文件可能会变得非常大。这些图会附加到文件中,因此要从一次运行到下一次运行保持干净,应显式删除该文件。XLA_SAVE_TENSORS_FMT
:_XLA_SAVE_TENSORS*FILE* 文件中存储的图的格式。可以是text
(默认值)、dot
(*Graphviz* 格式)或hlo
。XLA_FLAGS=--xla_dump_to
:如果设置为=/tmp/dir_name
,则 XLA 编译器将在每次编译时转储未优化和优化的 HLO。XLA_METRICS_FILE
:如果设置,则为本地文件的路径,将在每一步保存内部指标。如果已存在,指标将附加到文件中。XLA_SAVE_HLO_FILE
:如果设置,则为本地文件的路径,在发生编译/执行错误时,将保存有问题的 HLO 图。XLA_SYNC_WAIT
:强制 XLA 张量同步操作等待其完成,然后再进行下一步。XLA_USE_EAGER_DEBUG_MODE
:强制 XLA 张量立即执行,这意味着逐个编译和执行 torch 操作。这有助于绕过较长的编译时间,但总体步骤时间会慢很多,并且内存使用量会更高,因为所有编译器优化都将被跳过。XLA_USE_BF16
:如果设置为 1,则在发送到 TPU 设备时,将所有 PyTorch Float 值转换为 BiFloat16。请注意,当使用XLA_USE_BF16=1
时,张量运算将以降低的精度完成,因此如果随着时间的推移累积,张量将不准确。例如# In reduced bfloat16 precision >>> torch.tensor(4096, dtype=torch.bfloat16) + torch.tensor(1, dtype=torch.bfloat16) tensor(4096., dtype=torch.bfloat16) # Whereas in full float32 precision >>> torch.tensor(4096) + torch.tensor(1) tensor(4097)
因此,要获得准确的指标(例如,多步的平均损失值),请使用手动混合精度,其中指标保持在 FP32 中。
XLA_USE_F16
:如果设置为 1,则在发送到支持它们的设备时,将所有 PyTorch Float 值转换为 Float16(PyTorch Half 类型)。TF_CPP_LOG_THREAD_ID
:如果设置为 1,则 TF 日志将显示线程 ID,有助于调试多线程进程。TF_CPP_VMODULE
:用于 TF VLOG 的环境变量,采用TF_CPP_VMODULE=name=value,...
的形式。请注意,对于 VLOG,您必须设置TF_CPP_MIN_LOG_LEVEL=0
。TF_CPP_MIN_LOG_LEVEL
:要打印的消息级别。TF_CPP_MIN_LOG_LEVEL=0
将打开 INFO 日志记录,TF_CPP_MIN_LOG_LEVEL=1
将打开 WARNING 日志记录,依此类推。我们的 PyTorch/XLATF_VLOG
默认使用tensorflow::INFO
级别,因此要查看 VLOG,请设置TF_CPP_MIN_LOG_LEVEL=0
。XLA_DUMP_HLO_GRAPH
:如果在编译或执行错误的情况下设置为=1
,则违规的 HLO 图将作为xla_util.cc
引发的运行时错误的一部分被转储。
常见的调试环境变量组合¶
以 IR 格式记录图形执行
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="text" XLA_SAVE_TENSORS_FILE="/tmp/save1.ir"
以 HLO 格式记录图形执行
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"
显示运行时和图形编译/执行的调试 VLOG
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3"
重现 PyTorch/XLA CI/CD 单元测试失败。¶
您可能会看到 PR 的一些测试失败,例如
To execute this test, run the following from the base repo dir:
PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8
直接在命令行中运行此命令不起作用。您需要将环境变量 TORCH_TEST_DEVICES
设置为您本地的 pytorch/xla/test/pytorch_test_base.py
。例如
TORCH_TEST_DEVICES=/path/to/pytorch/xla/test/pytorch_test_base.py PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8
应该可以工作。
PJRT 运行时¶
PyTorch/XLA 已从基于 TensorFlow 的 XRT 运行时迁移到 PJRT 运行时,后者由 JAX 使用。
如果您在使用 PJRT 时遇到错误,请在 GitHub 上提交一个带有 runtime
标签的问题。
PyTorch/XLA r2.1 中的新功能:
PJRT 在 PyTorch/XLA r2.1 中稳定!
公共运行时 API 已从
torch_xla.experimental.pjrt
移动到torch_xla.runtime
。pjrt://
初始化方法已重命名为xla://
,并由torch_xla.distributed.xla_backend
注册。以前的
torch_xla.experimental.*
名称在此版本中仍然可用,以确保兼容性。
当使用
init_method='xla://'
时,现在支持torchrun
。通过 PJRT C API 为 XPU 和 Neuron 提供的新插件。
PyTorch/XLA r2.0 中的新功能:
如果您没有传入任何其他运行时配置,则默认情况下将配置 PJRT。如果您继续设置 XRT 配置(
XRT_TPU_CONFIG
),则此更改不会产生影响libtpu
中新的 TPU 运行时实现将性能提高了高达 30%。新的
xm.rendezvous
实现可以扩展到数千个 TPU 核心[实验性] 对 TPU v2 和 v3 的
torch.distributed
支持,包括pjrt://
init_method
TL;DR¶
要使用 PJRT 预览运行时,请将
PJRT_DEVICE
环境变量设置为CPU
、TPU
或CUDA
在 XRT 中,所有分布式工作负载都是多进程的,每个设备一个进程。在 PJRT 中的 TPU v2 和 v3 上,工作负载是多进程和多线程的(4 个进程,每个进程 2 个线程),因此您的工作负载应该是线程安全的。有关更多信息,请参阅TPU v2/v3 上的多线程和API 指南的多处理部分。要记住的主要区别
要以线程安全的方式初始化模型,请在初始化后跨副本广播参数(
torch_xla.experimental.pjrt.broadcast_master_param
)或从公共检查点加载每个副本的参数。对于其他随机数生成,请尽可能使用
torch.Generator
。全局torch
RNG *不是*线程安全的,即使您在副本之间设置了相同的torch.manual_seed
。要使用
torch.distributed
,请导入torch_xla.experimental.pjrt_backend
并使用xla://
init_method
。这些步骤对于 GPU 和 TPU v4 是可选的。
从 XRT 到 PJRT 的示例差异
import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr
def _mp_fn(index):
device = xm.xla_device()
- dist.init_process_group('xla', rank=xm.get_ordinal(), world_size=xm.xrt_world_size())
+ dist.init_process_group('xla', init_method='xla://')
torch.manual_seed(42)
model = nn.Linear(128, 10).to(device)
+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
for i in range(10):
data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Print mean parameters so we can confirm they're the same across replicas
print([p.mean() for p in model.parameters()])
if __name__ == '__main__':
- os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
+ # Recommended: set PJRT_DEVICE to your local device type
+ os.environ['PJRT_DEVICE'] = 'TPU'
xmp.spawn(_mp_fn)
好处¶
简单的运行时配置:只需将
PJRT_DEVICE
设置为TPU
、CPU
或CUDA
并开始使用 XLA!或者,让 PJRT 根据您的环境自动选择设备。改进的性能:减少 gRPC 的开销意味着更快的端到端执行。在 TorchBench 2.0 上,我们观察到 TPU v4 上的训练时间缩短了 >35%。
轻松的 Pod 执行:只需将您的代码复制到每个 TPU 工作器,并使用
gcloud compute tpus tpuvm ssh --worker=all
同时执行它们。更好的扩展性:消除了XRT 对参数大小的限制,并支持多达 2048 个 TPU 芯片。
快速入门¶
要开始在 PyTorch/XLA 中使用 PJRT,您只需设置 PJRT_DEVICE
环境变量即可。如果您使用的是 TPU v2 或 v3,请继续阅读以了解 TPU v2 和 v3 与 v4 之间的区别。
CPU¶
在安装了 PyTorch/XLA 的任何机器上,您都可以在 CPU 上运行我们的 MNIST 示例,如下所示
PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data
TPU¶
要使用安装了 PyTorch/XLA r2.0 的新 TPU
gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT
在 v4-8 上,您可以运行我们的 ResNet50 示例,如下所示
git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
默认情况下,PJRT 将使用所有 TPU 芯片。要仅使用一个 TPU 芯片,请配置 TPU_PROCESS_BOUNDS
和 TPU_VISIBLE_CHIPS
TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Pod¶
在 TPU Pod 上,使用 gcloud
在每个 TPU 上并行运行您的命令
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
Docker¶
您还可以使用 Docker 在预装了 PyTorch/XLA 的容器中运行您的工作负载
export DOCKER_IMAGE=gcr.io/...
# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"
# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"
请注意,docker run
需要对主机具有特权访问权限(--privileged
),才能将 TPU 设备暴露给容器。目前,仅在主机网络 --net=host
中支持 TPU Pod 上的 Docker。有关更多信息,请参阅Cloud TPU 文档。
GPU¶
单节点 GPU 训练¶
要将 GPU 与 PJRT 一起使用,只需设置 PJRT_DEVICE=CUDA
并将 GPU_NUM_DEVICES
配置为主机上的设备数量。例如
PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
您还可以使用 torchrun
启动单节点多 GPU 训练。例如,
PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在上面的示例中,--nnodes
表示要使用多少台机器(物理机或虚拟机)(由于我们进行单节点训练,因此为 1)。 --nproc-per-node
表示要使用多少个 GPU 设备。
多节点 GPU 训练¶
**请注意,此功能仅适用于 cuda 12+**。与 PyTorch 使用多节点训练的方式类似,您可以运行如下命令
PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
--nnodes
:要使用多少台 GPU 机器。--node_rank
:当前 GPU 机器的索引。该值可以是 0、1、…、${NUMBER_GPU_VM}-1。--nproc_per_node
:要在当前机器上使用的 GPU 设备数量。–rdzv_endpoint:节点排名 (node_rank) 为 0 的 GPU 设备的端点,格式为 host:port`。其中,“host
将 是 内部 IP 地址。
“port”` 可以是设备上任何可用的端口。对于单节点训练/推理,可以省略此参数。
例如,如果您想在 2 台 GPU 设备(machine_0 和 machine_1)上进行训练,请在第一台 GPU 设备 machine_0 上运行
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在第二台 GPU 设备上运行
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
上面两条命令的区别在于 --node_rank
,以及如果您想在每台设备上使用不同数量的 GPU 设备,则可能还有 --nproc_per_node
。其余所有参数都相同。有关 torchrun
的更多信息,请参阅此 页面。
与 XRT 的区别¶
尽管在大多数情况下,我们预计 PJRT 和 XRT 从最终用户的角度来看基本上可以互换使用(尤其是在 TPU v4 上),但仍有一些细微的差别需要注意。重要的是,XRT 是围绕 TPU 节点架构设计的,因此它始终会生成一个客户端进程和一个服务器进程,即使在 TPU VM 上也是如此。因此,每批输入都会因序列化和反序列化数据以通过网络发送而增加延迟。
PJRT 直接使用本地设备,没有中间服务器进程。在默认配置中,PJRT 将为每个 TPU 芯片创建一个进程,或者为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档。
对于受限于开销的工作负载,可以提高性能。
在 XRT 下,服务器进程是唯一与 TPU 设备交互的进程,客户端进程没有对 TPU 设备的直接访问权限。在分析单主机 TPU(例如 v3-8 或 v4-8)时,您通常会看到 8 个设备跟踪(每个 TPU 核心一个)。使用 PJRT,每个进程都有一个芯片,并且该进程的配置文件将仅显示 2 个 TPU 核心。
出于相同的原因,分析在使用 XRT 的 TPU Pod 上不起作用,因为服务器进程独立于用户的模型代码运行。PJRT 没有这个限制,因此可以在 TPU Pod 中分析每个进程 2 个 TPU 核心。
PJRT 仅支持 TPU VM 架构,我们没有计划使用 PJRT 支持 TPU 节点架构。
使用 PJRT,运行时配置要简单得多。运行 TPU Pod 工作负载不需要
xla_dist
。相反,将您的代码复制到每个 TPU 主机([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)
),并在每个主机上并行运行代码(例如[gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh)
)。xm.rendezvous
已使用 XLA 原生集体通信重新实现,以增强大型 TPU pod 上的稳定性。有关更多详细信息,请参见下文。
TPU v2/v3 上的多线程¶
在 TPU v2 和 v3 上,分布式工作负载始终以多线程方式运行,因为每个 TPU 芯片都将两个 TPU 核心公开为设备,并且一次只能有一个进程打开一个 TPU 芯片。在其默认配置中,xmp.spawn
会自动生成尽可能多的进程(每个 TPU 主机 4 个),并为每个进程创建两个线程(每个 TPU 核心一个)。
注意:在 TPU v4 上,每个 TPU 芯片都表示为一个 PyTorch 设备,因此分布式工作负载将在 4 个进程中运行,每个进程只有一个线程。这与 XRT 的行为相同。
在大多数情况下,这不需要对现有代码进行重大更改。在大多数情况下,您需要进行的主要更改是模型初始化。因为 torch
的全局 RNG 在线程之间共享,所以即使您在每个副本中将 torch.manual_seed
设置为相同的值,结果也会因线程和运行而异。为了在副本之间获得一致的参数,可以使用 torch_xla.experimental.pjrt.broadcast_master_param
将一个副本的参数广播到所有其他副本,或者从一个公共检查点加载每个副本的参数。
对 xm.rendezvous 的更改¶
PyTorch/XLA r2.0 中的新功能
使用 XRT 时,工作器 0 运行网格主服务,所有工作器上的所有进程都通过 gRPC 连接到该服务。在实践中,我们发现在具有数千个芯片的 TPU pod 上运行单个网格主进程是不可靠的,因为到工作器 0 的入站连接数量很多。单个客户端进程超时可能会导致故障,并强制整个工作负载重新启动。
因此,我们使用原生 XLA 集体通信重新实现了 xm.rendezvous
,这在大型 TPU pod 上更加稳定且经过了充分测试。与 XRT 实现相比,这带来了两个新的限制
因为有效负载必须成为 XLA 图的一部分,所以在传输数据之前和之后都会调用
xm.mark_step
。在模型代码中间调用xm.rendezvous
可能会强制进行不必要的编译。因为 XLA 不允许在工作器子集上运行集体操作,所以所有工作器都必须参与
rendezvous
。
如果您需要 xm.rendezvous
的旧行为(即在不更改 XLA 图的情况下传输数据和/或同步工作器子集),请考虑使用 ``torch.distributed.barrier` <https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.barrier>`_ 或 ``torch.distributed.all_gather_object` <https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.all_gather_object>`_,以及 gloo
进程组。如果您还使用 xla
torch.distributed
后端,可以使用 torch.new_group
创建 gloo
子组。请参阅 PyTorch 文档中的 此示例。请记住以下限制
TPU v2/v3 不完全支持
torch.distributed
。仅实现了xla
后端的少量操作,gloo
在多线程环境中可能无法按预期工作。在我们的实验中,
gloo
无法很好地扩展到数千个 TPU 芯片,因此预计此替代方案的可靠性不如在大型规模上使用 PJRT 的xm.rendezvous
。
PJRT 和 torch.distributed¶
PyTorch/XLA r2.0 中的新功能
在将 PJRT 与 torch.distributed
和 [torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)
一起使用时,我们强烈建议使用新的 xla://
init_method
,它通过查询运行时自动查找副本 ID、世界大小和主 IP。例如
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt
# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend
def _all_gather(index: int):
# No need to pass in `rank` or `world_size`
dist.init_process_group('xla', init_method='xla://')
t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(output, t)
xm.mark_step()
print(output)
if __name__ == '__main__':
xmp.spawn(_all_gather)
注意:虽然 TPU v4 上不需要 xla://
init_method,但仍然建议使用。如果您使用 env://
,则必须将 MASTER_ADDR
设置为具有设备 0 的 IP 主机,该主机并非始终是工作器 0。xla://
init_method 会自动找到此 IP。
注意:对于 TPU v2/v3,您仍然需要导入 torch_xla.experimental.pjrt_backend
,因为 torch.distributed
中的 TPU v2/v3 支持仍处于实验阶段。
有关在 PyTorch/XLA 上使用 DistributedDataParallel
的更多信息,请参阅 ``ddp.md` <./ddp.md>`_(关于 TPU V4)。有关同时使用 DDP 和 PJRT 的示例,请在 TPU 上运行以下 示例脚本
PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1
性能¶
TorchBench 显示,与 XRT 相比,PJRT 在 TPU v4-8 上的各项任务的平均训练时间都有所改善,平均改善超过 35%。收益因任务和模型类型而异,从 0% 到 175% 不等。下图显示了按任务细分的对比情况
新的 TPU 运行时¶
PyTorch/XLA r2.0 中的新功能
PyTorch/XLA r2.0 版本引入了对 PJRT 插件 API 的支持,该 API 用于访问 libtpu
中基于 TFRT 的新 TPU 运行时。现在,当设置 PJRT_DEVICE=TPU
时,这是默认运行时。1.13 中使用的基于 StreamExecutor 的旧版 TPU 运行时在 2.0 版本中仍然可以通过 PJRT_DEVICE=TPU_LEGACY
使用,但将在以后的版本中删除。如果您遇到仅在 TPU
上发生而在 TPU_LEGACY
上未发生的问题,请在 GitHub 上提交问题。
在大多数情况下,我们预计两个运行时的性能相似,但在某些情况下,新运行时的速度可能最高提升 30%。下图显示了按任务细分的对比情况
注意:此图表中显示的改进也包含在 PJRT 与 XRT 的比较中。
PyTorch XLA 中的 TorchDynamo(torch.compile) 集成¶
TorchDynamo 是一个 Python 级 JIT 编译器,旨在让未经修改的 PyTorch 程序运行得更快。它为编译器后端提供了一个简洁的 API 来挂钩,其最大特点是在执行之前动态修改 Python 字节码。在 pytorch/xla 2.0 版本中,PyTorch/XLA 为 TorchDynamo 提供了一个实验性后端,用于推理和训练。
XLA 桥接器的工作方式是,当 Dynamo 识别出模型模式时,它将提供一个 TorchFX 图,PyTorch/XLA 将使用现有的 Lazy Tensor 技术编译 FX 图并返回编译后的函数。
集成¶
目前,通过向 torch.compile
添加 backend='openxla'
参数,可以支持 PyTorch/XLA 和 Dynamo。例如
import torch
import torch_xla.core.xla_model as xm
def add(a, b):
a_xla = a.to(xm.xla_device())
b_xla = b.to(xm.xla_device())
return a_xla + b_xla
compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))
推理¶
以下是用 torch.compile
运行 resnet18 的一个小代码示例
import torch
import torchvision
import torch_xla.core.xla_model as xm
def eval_model(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
xla_resnet18, backend='openxla')
for data, _ in loader:
with torch.no_grad():
output = dynamo_resnet18(data)
使用 torch.compile
,您会发现 PyTorch/XLA 仅在初始化期间跟踪一次 resent18 模型,并在每次调用 dynamo_resnet18
时执行已编译的二进制文件,而不是每次都跟踪模型。以下是在 Cloud TPU v4-8 上使用 torch bench 比较 Dynamo 和 Lazy 的推理速度分析
resnet18 | 2.59 resnet50 | 2.64 resnext50_32x4d | 1.91 alexnet | 1.28 mobilenet_v2 | 18.62 mnasnet1_0 | 2.68 vgg16 | 1.33 BERT_pytorch | 7.49 squeezenet1_1 | 2.29 timm_vision_transformer | 3.52 geomean | 3.04
训练¶
PyTorch/XLA 也支持使用 Dynamo 进行训练,但这项功能尚处于实验阶段,我们正与 PyTorch 编译器团队合作迭代实施。以下是一个使用 torch.compile
训练 resnet18 的示例
import torch
import torchvision
import torch_xla.core.xla_model as xm
def train_model(model, data, target, optimizer):
loss_fn = torch.nn.CrossEntropyLoss()
pred = model(data)
loss = loss_fn(pred, target)
loss.backward()
optimizer.step()
return pred
def train_model_main(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
train_model, backend='openxla')
for data, target in loader:
xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)
我们预计每个训练步骤提取并执行 3 个图,而不是像使用 Lazy 张量那样每个训练步骤提取并执行 1 个图。以下是在 Cloud TPU v4-8 上使用 torch bench 比较 Dynamo 和 Lazy 的训练速度分析。
resnet50 | 1.33 resnet18 | 1.33 BERT_pytorch | 3.07 resnext50_32x4d | 1.43 alexnet | 1.12 mobilenet_v2 | 1.4 mnasnet1_0 | 1.19 vgg16 | 0.81 timm_vision_transformer | 1.87 squeezenet1_1 | 1.41 geomean | 1.41
**注意:**我们对每个模型的前向和后向运行单个步骤,然后收集端到端时间。在现实世界中,我们将在每个训练作业中运行多个步骤,这可以很容易地隐藏执行中的跟踪成本(因为它是异步的)。在这种情况下,Lazy Tensor 的性能会好得多。
功能差距¶
我们想指出一个差距,它阻碍了我们在更大规模的模型上使用 TorchDynamo。
TorchDynamo 会将前向和后向跟踪到不同的图中。对于 PyTorch/XLA,让 XLA 编译器将整个步骤视为一个图来最好地优化速度非常重要。启动每个设备执行还有一个固定的开销,这使得每个训练步骤执行多个图不太理想。
与 Lazy Tensor 相比,这种差距使其在现实世界的训练用例中效率较低,尤其是在训练中跟踪成本可以与执行重叠。
总结¶
TorchDynamo 为编译器后端提供了一种非常有前景的方式,可以向用户隐藏复杂性,并轻松地以图的形式检索建模代码。与 PyTorch/XLA 传统的 Lazy Tensor 提取图的方式相比,TorchDynamo 可以跳过每次迭代的图跟踪,从而提供更好的推理响应时间。
PyTorch/XLA 支持的大多数模型在使用新的 dynamo-xla 桥接器运行推理时都获得了显著的加速。我们的社区正在努力扩展支持的模型集。关于上面提到的训练功能差距,PyTorch/XLA 社区非常高兴在我们即将开展的开发工作中缩小训练差距。团队将继续大力投入 TorchDynamo,并与上游合作,使训练故事更加成熟。
PyTorch XLA 中的全分片数据并行 (FSDP)¶
PyTorch XLA 中的全分片数据并行 (FSDP) 是一种用于在数据并行工作器之间分片模块参数的实用程序。
使用示例
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
也可以单独分片各个层,并使用外部包装器处理任何剩余的参数。
注意
XlaFullyShardedDataParallel
类支持 https://arxiv.org/abs/1910.02054 中的 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态)。ZeRO-3 优化器应该通过嵌套的 FSDP 来实现,并使用
reshard_after_forward=True
。有关示例,请参阅test/test_train_mp_mnist_fsdp_with_ckpt.py
和test/test_train_mp_imagenet_fsdp.py
。对于无法放入单个 TPU 内存或主机 CPU 内存的大型模型,应该将子模块构造与内部 FSDP 包装交织在一起。有关示例,请参阅 ``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_。
提供了一个简单的包装器
checkpoint_module
(基于 https://github.com/pytorch/xla/pull/3524 中的torch_xla.utils.checkpoint.checkpoint
),用于对给定的nn.Module
实例执行 梯度检查点。有关示例,请参阅test/test_train_mp_mnist_fsdp_with_ckpt.py
和test/test_train_mp_imagenet_fsdp.py
。自动包装子模块:除了手动嵌套 FSDP 包装外,还可以指定
auto_wrap_policy
参数来自动使用内部 FSDP 包装子模块。torch_xla.distributed.fsdp.wrap
中的size_based_auto_wrap_policy
是auto_wrap_policy
可调用对象的示例,此策略包装参数数量大于 1 亿的层。torch_xla.distributed.fsdp.wrap
中的transformer_auto_wrap_policy
是针对类似 Transformer 的模型架构的auto_wrap_policy
可调用对象的示例。
例如,要使用内部 FSDP 自动包装所有 torch.nn.Conv2d
子模块,可以使用
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
此外,还可以指定 auto_wrapper_callable
参数来为子模块使用自定义的可调用包装器(默认包装器只是 XlaFullyShardedDataParallel
类本身)。例如,可以使用以下内容对每个自动包装的子模块应用梯度检查点(即激活检查点/重新实例化)。
from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
checkpoint_module(m), *args, **kwargs)
在执行优化器步骤时,直接调用
optimizer.step
,而不要调用xm.optimizer_step
。后者会减少跨等级的梯度,而 FSDP(其中参数已经分片)不需要这样做。在训练期间保存模型和优化器检查点时,每个训练进程都需要保存其自己的(分片的)模型和优化器状态字典的检查点(使用
master_only=False
并为xm.save
中的每个等级设置不同的路径)。恢复时,它需要加载对应等级的检查点。还请将
model.get_shard_metadata()
与model.state_dict()
一起保存,如下所示,并使用consolidate_sharded_model_checkpoints
将分片的模型检查点拼接在一起,形成一个完整的模型状态字典。有关示例,请参阅test/test_train_mp_mnist_fsdp_with_ckpt.py
。.. code-block:: python3- ckpt = {
‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),
} ckpt_path = f’/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)
也可以从命令行启动检查点合并脚本,如下所示。.. code-block:: bash
# 通过命令行工具合并保存的检查点 python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”
此类的实现很大程度上受到 https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 中的 fairscale.nn.FullyShardedDataParallel
的启发,并且大部分遵循其结构。与 fairscale.nn.FullyShardedDataParallel
的最大区别之一是,在 XLA 中,我们没有显式的参数存储,因此我们采用了一种不同的方法来释放 ZeRO-3 的完整参数。
MNIST 和 ImageNet 上的示例训练脚本¶
MNIST:``test/test_train_mp_mnist_fsdp_with_ckpt.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py>`_(它还测试了检查点合并)
ImageNet:``test/test_train_mp_imagenet_fsdp.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py>`_
安装¶
FSDP 在 PyTorch/XLA 1.12 版本和更新的 nightly 版本中可用。有关安装指南,请参阅 https://github.com/pytorch/xla#-available-images-and-wheels。
克隆 PyTorch/XLA 存储库¶
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST¶
经过 2 个 epoch 的训练,准确率约为 98.9%
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp --use_gradient_checkpointing
该脚本会在最后自动测试检查点合并。您也可以通过以下方式手动合并分片的检查点
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet¶
经过 100 个 epoch 的训练,准确率约为 75.9%;将 ImageNet-1k 下载到 /datasets/imagenet-1k
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
您还可以添加 --use_gradient_checkpointing
(需要与 --use_nested_fsdp
或 --auto_wrap_policy
一起使用)以对残差块应用梯度检查点。
TPU pod 上的示例训练脚本(具有 100 亿个参数)¶
为了训练无法放入单个 TPU 的大型模型,在构建整个模型以实现 ZeRO-3 算法时,应该应用自动包装或手动使用内部 FSDP 包装子模块。
有关使用此 XLA FSDP PR 对 Vision Transformer (ViT) 模型进行分片训练的示例,请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example。
如何执行 DistributedDataParallel
¶
本文档介绍了如何在 xla 中使用 torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生 xla 数据并行方法的区别。
背景/动机¶
长期以来,客户一直要求能够将 PyTorch 的 DistributedDataParallel API 与 xla 一起使用。在这里,我们将其作为一项实验性功能启用。
如何使用 DistributedDataParallel¶
对于从 PyTorch eager 模式切换到 XLA 的用户,以下是将 eager DDP 模型转换为 XLA 模型所需进行的所有更改。我们假设您已经知道如何在单个设备上使用 XLA 。
导入 xla 特定的分布式包
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
初始化 xla 进程组,类似于其他进程组,例如 nccl 和 gloo。
dist.init_process_group("xla", rank=rank, world_size=world_size)
如果需要,使用 xla 特定的 API 获取等级和 world_size。
new_rank = xm.get_ordinal()
world_size = xm.xrt_world_size()
将
gradient_as_bucket_view=True
传递给 DDP 包装器。
ddp_model = DDP(model, gradient_as_bucket_view=True)
最后使用 xla 特定的启动器启动您的模型。
xmp.spawn(demo_fn)
在这里,我们将所有内容放在一起(该示例实际上取自DDP 教程)。您的编码方式与 eager 体验非常相似。只是在单个设备上使用 xla 特定的触控,以及对脚本进行上述五项更改。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# additional imports for xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the xla process group
dist.init_process_group("xla", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 1000000)
self.relu = nn.ReLU()
self.net2 = nn.Linear(1000000, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank):
# xla specific APIs to get rank, world_size.
new_rank = xm.get_ordinal()
assert new_rank == rank
world_size = xm.xrt_world_size()
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
# currently, graident_as_bucket_view is needed to make DDP work for xla
ddp_model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10).to(device))
labels = torch.randn(20, 5).to(device)
loss_fn(outputs, labels).backward()
optimizer.step()
# xla specific API to execute the graph
xm.mark_step()
cleanup()
def run_demo(demo_fn):
# xla specific launcher
xmp.spawn(demo_fn)
if __name__ == "__main__":
run_demo(demo_basic)
基准测试¶
使用虚假数据的 Resnet50¶
以下结果是使用以下命令收集的:python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
在具有 ToT PyTorch 和 PyTorch/XLA 的 TPU VM V3-8 环境中。统计指标是使用此拉取请求中的脚本生成的。速率单位是每秒图像数。
类型 | 平均值 | 中位数 | 第 90 个百分位数 | 标准差 | CV |
xm.optimizer_step | 418.54 | 419.22 | 430.40 | 9.76 | 0.02 |
DDP | 395.97 | 395.54 | 407.13 | 7.60 | 0.02 |
我们用于分布式数据并行的原生方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。考虑到 DDP 包装器在跟踪 DDP 运行时引入了额外的开销,因此该结果似乎是合理的。
使用虚假数据的 MNIST¶
以下结果是使用以下命令收集的:python test/test_train_mp_mnist.py --fake_data
在具有 ToT PyTorch 和 PyTorch/XLA 的 TPU VM V3-8 环境中。统计指标是使用此拉取请求中的脚本生成的。速率单位是每秒图像数。
类型 | 平均值 | 中位数 | 第 90 个百分位数 | 标准差 | CV |
xm.optimizer_step | 17864.19 | 20108.96 | 24351.74 | 5866.83 | 0.33 |
DDP | 10701.39 | 11770.00 | 14313.78 | 3102.92 | 0.29 |
我们用于分布式数据并行的原生方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。在这里,我们比较了第 90 个百分位数,因为数据集很小,前几轮受数据加载的影响很大。这种减速是巨大的,但考虑到模型很小,这是有道理的。额外的 DDP 运行时跟踪开销难以摊销。
使用真实数据的 MNIST¶
以下结果是使用以下命令收集的:python test/test_train_mp_mnist.py --logdir mnist/
在具有 ToT PyTorch 和 PyTorch/XLA 的 TPU VM V3-8 环境中。

我们可以观察到,即使 DDP 包装器最终仍能达到 97.48% 的高准确率,但它的收敛速度比原生 XLA 方法慢。(原生方法达到了 99%。)
免责声明¶
此功能仍处于实验阶段,并且正在积极开发中。请谨慎使用,并随时将任何错误报告给xla github 存储库。对于那些对原生 xla 数据并行方法感兴趣的人,这里是教程。
以下是一些正在调查的已知问题
gradient_as_bucket_view=True
需要强制执行。与
torch.utils.data.DataLoader
一起使用时存在一些问题。test_train_mp_mnist.py
在退出之前使用真实数据崩溃。
如何使用 PyTorch/XLA:GPU 运行¶
PyTorch/XLA 使 PyTorch 用户能够利用 XLA 编译器,该编译器支持 TPU、GPU 和 CPU 等加速器。本文档将介绍在 nvidia GPU 实例上运行 PyTorch/XLA 的基本步骤。
环境设置¶
Docker¶
Pytorch/XLA 目前发布了使用 cuda11.7/8 和 python 3.8 的预构建 docker 镜像和 wheel。我们建议用户创建一个具有相应配置的 docker 容器。有关 docker 镜像和 wheel 的完整列表,请参阅本文档。
sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1
sudo apt-get install -y apt-transport-https ca-certificates curl gnupg-agent software-properties-common
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvda.org.cn/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvda.org.cn/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 bin/bash
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash
请注意,您需要重新启动 docker 以使 gpu 设备在 docker 容器中可见。登录 docker 后,您可以使用 nvidia-smi
验证设备是否已正确设置。
(pytorch) root@20ab2c7a2d06:/# nvidia-smi
Thu Dec 8 06:24:29 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |
| N/A 36C P0 38W / 300W | 0MiB / 16384MiB | 1% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
检查环境变量¶
确保 PATH
和 LD_LIBRARY_PATH
环境变量考虑了 cuda。请执行 echo $PATH
和 echo $LD_LIBRARY_PATH
进行验证。如果不是,请按照链接进行操作。示例
echo "export PATH=/usr/local/cuda-12.1/bin${PATH:+:${PATH}}" >> ~/.bashrc
echo "export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> ~/.bashrc
source ~/.bashrc
Wheel¶
pip3 install torch==2.2.0
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl
运行简单模型¶
为了运行以下示例,您需要克隆 pytorch/xla 存储库以访问 imagenet 示例(我们已经在 docker 中克隆了它)。
(pytorch) root@20ab2c7a2d06:/# export GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA
(pytorch) root@20ab2c7a2d06:/# git clone --recursive https://github.com/pytorch/xla.git
(pytorch) root@20ab2c7a2d06:/# python xla/test/test_train_mp_imagenet.py --fake_data
==> Preparing data..
Epoch 1 train begin 06:12:38
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=2.82 GlobalRate=2.82 Time=06:13:23
| Training Device=xla:0/0 Epoch=1 Step=20 Loss=6.79297 Rate=117.16 GlobalRate=45.84 Time=06:13:36
| Training Device=xla:0/0 Epoch=1 Step=40 Loss=6.43628 Rate=281.16 GlobalRate=80.49 Time=06:13:43
| Training Device=xla:0/0 Epoch=1 Step=60 Loss=5.83108 Rate=346.88 GlobalRate=108.82 Time=06:13:49
| Training Device=xla:0/0 Epoch=1 Step=80 Loss=4.99023 Rate=373.62 GlobalRate=132.43 Time=06:13:56
| Training Device=xla:0/0 Epoch=1 Step=100 Loss=3.92699 Rate=384.33 GlobalRate=152.40 Time=06:14:02
| Training Device=xla:0/0 Epoch=1 Step=120 Loss=2.68816 Rate=388.35 GlobalRate=169.49 Time=06:14:09
AMP(自动混合精度)¶
AMP 在 GPU 训练中非常有用,PyTorch/XLA 重用了 Cuda 的 AMP 规则。您可以查看我们的mnist 示例和imagenet 示例。请注意,我们还使用了修改后的优化器,以避免设备和主机之间的额外同步。
在 GPU 实例上开发 PyTorch/XLA(使用 GPU 支持从源代码构建 PyTorch/XLA)¶
在 GPU VM 中,从开发 docker 镜像创建 docker 容器。例如
sudo docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1
sudo apt-get install -y apt-transport-https ca-certificates curl gnupg-agent software-properties-common
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvda.org.cn/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvda.org.cn/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
sudo docker run --shm-size=16g --net=host --gpus all -it -d us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash
从源代码构建 PyTorch 和 PyTorch/XLA。
确保 PATH
和 LD_LIBRARY_PATH
环境变量考虑了 cuda。有关更多信息,请参阅以上内容。
git clone https://github.com/pytorch/pytorch.git
cd pytorch
USE_CUDA=1 python setup.py install
git clone https://github.com/pytorch/xla.git
cd xla
XLA_CUDA=1 python setup.py install
验证 PyTorch 和 PyTorch/XLA 是否已成功安装。
如果您可以在运行简单模型部分中成功运行测试,则 PyTorch 和 PyTorch/XLA 应该已成功安装。
PyTorch/XLA SPMD 用户指南¶
在本用户指南中,我们将讨论GSPMD如何集成到 PyTorch/XLA 中,并提供一个设计概述来说明 SPMD 分片注释 API 及其结构是如何工作的。然后,我们提供了一系列参考示例供用户尝试。
什么是 PyTorch/XLA SPMD?¶
GSPMD是一种针对常见 ML 工作负载的自动并行化系统。XLA 编译器将根据用户提供的分片提示,将单设备程序转换为具有适当集合的分区程序。此功能允许开发人员编写 PyTorch 程序,就好像它们位于一个大型设备上一样,而无需任何自定义分片计算操作和/或集合通信来扩展。

*图 1. 两种不同执行策略的比较,(a) 非 SPMD 和 (b) SPMD。*
为了在 PyTorch/XLA 中支持 GSPMD,我们引入了一种新的执行模式。在 GSPMD 之前,PyTorch/XLA 中的执行模式假设有多个模型副本,每个副本都有一个核心(图 1.a)。如上所述,这种执行模式适用于数据并行框架,例如流行的 PyTorch 分布式数据并行 (DDP)或完全分片数据并行 (FSDP),但它也受到副本只能驻留在一个设备核心中执行的限制。PyTorch/XLA SPMD 引入了一种新的执行模式,该模式假设有一个具有多个核心的副本(图 1.b),允许副本跨多个设备核心运行。这种转变解锁了更高级的并行策略,以获得更好的大型模型训练性能。
PyTorch/XLA SPMD 在新的PJRT运行时上可用。要启用 PyTorch/XLA SPMD 执行模式,用户必须调用[use_spmd() API](https://github.com/pytorch/xla/blob/b8b484515a97f74e013dcf38125c44d53a41f011/torch_xla/runtime.py#L214)
。
import torch_xla.runtime as xr
# Enable PyTorch/XLA SPMD execution mode.
xr.use_spmd()
assert xr.is_spmd() == True
重要的是要注意,SPMD 可以替代任何现有的并行机制,包括 DDP 和 FSDP。用户不能混合两种不同的执行模式(SPMD 和非 SPMD),在本指南的后面部分,我们将介绍如何使用 SPMD 注释来执行 DDP 和 FSDP。
此外,此版本的 SPMD 目前仅在 Google Cloud TPU 上进行了测试和优化。GPU 支持和优化将在 2.2 版本中提供。
PyTorch/XLA SPMD 设计概述¶
简单示例和分片注释 API¶
用户可以使用 mark_sharding
API(源代码)注释原生 PyTorch 张量。它将 torch.Tensor
作为输入,并返回 XLAShardedTensor
作为输出。
def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Union[int, None]]) -> XLAShardedTensor
调用 mark_sharding
API 需要用户定义的逻辑网格和分区规范,并为 XLA 编译器生成分片注释。分片规范附加到 XLATensor。以下是 [RFC 中的一个简单使用示例,用于说明分片注释 API 的工作原理
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
mesh_shape = (2, 4)
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
t = torch.randn(8, 4).to(xm.xla_device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = (0, 1)
m1_sharded = xs.mark_sharding(t, mesh, partition_spec)
assert isinstance(m1_sharded, XLAShardedTensor) == True
我们可以在 PyTorch 程序中注释不同的张量,以启用不同的并行技术,如下面的注释中所述
# Sharding annotate the linear layer weights.
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
# Assumes `loader` returns data, target on XLA device
optimizer.zero_grad()
# Sharding annotate input data, we can shard any input
# dimensions. Sharidng the batch dimension enables
# in data parallelism, sharding the feature dimension enables
# spatial partitioning.
xs.mark_sharding(data, mesh, partition_spec)
ouput = model(data)
loss = loss_fn(output, target)
optimizer.step()
xm.mark_step()
更完整的单元测试用例和集成测试示例可以在 PyTorch/XLA repo 中找到。
网格¶
对于给定的设备集群,物理网格是互连拓扑的表示。
我们基于此拓扑推导出一个逻辑网格,以创建设备子组,这些子组可用于划分模型中张量的不同轴。

我们使用 Mesh API 抽象逻辑网格。逻辑网格的轴可以命名。以下是一个示例
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh
# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' nad 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh.get_logical_mesh()
>> array([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
mesh.shape()
>> OrderedDict([('x', 4), ('y', 2)])
通常,SPMD 程序应创建一个网格并将其用于所有分片,以确保切片分配与预期的分片策略一致。通过操作分区规范(如下所述),可以将同一个网格用于不同形状和分片的张量。
混合网格¶
网格很好地抽象了物理设备网格是如何构建的。用户可以使用逻辑网格以任何形状和顺序排列设备。但是,可以根据物理拓扑定义性能更高的网格,尤其是在涉及数据中心网络 (DCN) 跨切片连接时。HybridMesh 创建了一个网格,该网格为此类多切片环境提供了开箱即用的良好性能。它接受 ici_mesh_shape 和 dcn_mesh_shape,它们表示内部和外部网络的逻辑网格形状。
from torch_xla.distributed.spmd import HybridMesh
# This example is assuming 2 slices of v4-8.
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
# - dcn_mesh_shape: shape of logical mesh for outer connected devices.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)
mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
分区规范¶
partition_spec 的秩与输入张量相同。每个维度描述了相应的输入张量维度如何在设备网格(由 mesh_shape 逻辑定义)上进行分片。partition_spec
是 device_mesh
维度 index
或 None 的元组。如果相应的网格维度已命名,则索引可以是 int
或 str
。这指定了每个输入秩如何分片(index
到 mesh_shape
)或复制(None
)。
# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (4, 2), ('data', 'model'))
partition_spec = ('model', 'data')
xs.mark_sharding(input_tensor, mesh, partition_spec)
我们支持原始 GSPMD 论文中描述的所有三种分片类型。例如,可以像这样指定部分复制
# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (2, 2, 2), ('x', 'y', 'z'))
# evenly shard across x and z and replicate among y
partition_spec = ('x', 'z') # equivalent to ('x', None, 'z')
xs.mark_sharding(input_tensor, mesh, partition_spec)
分区规范允许为不同的张量形状和所需的分片策略重复使用同一个网格。以下示例使用 3D 网格对此进行了演示
# Create a 3-D mesh of 8 devices with logical dimensions replica, fsdp, and
# tensor
mesh = Mesh(device_ids, (2, 2, 2), ('replica', 'fsdp', 'tensor'))
# A 2D tensor can be sharded along the fsdp and tensor axes and replicated
# along the replica axis by omitting `replica` from the partition spec.
two_d_partially_replicated = torch.randn(64, 64, device='xla')
xs.mark_sharding(two_d_partially_replicated, mesh, ('fsdp', 'tensor'))
# A 2D tensor can be sharded across all dimensions by combining, for example,
# the replica and fsdp mesh axes using a tuple
two_d_fully_sharded = torch.randn(64, 64, device='xla')
xs.mark_sharding(two_d_fully_sharded, mesh, (('replica', 'fsdp'), 'tensor'))
# A 4D tensor can be sharded along up to three of its axes using the 3D mesh
four_d = torch.randn(64, 64, 64, 64, device='xla')
xs.mark_sharding(four_d, ('replica', 'fsdp', None, 'tensor'))
XLAShardedTensor¶
XLAShardedTensor
[RFC] 的主要用例是用分片规范注释原生 torch.tensor
(在单个设备上)。注释会立即进行,但张量的实际分片会被延迟,因为计算是惰性执行的,除了输入张量会立即分片。一旦张量被注释并封装在 XLAShardedTensor
中,它就可以作为 torch.Tensor
传递给现有的 PyTorch 运算符和 nn.Module
层。这对于确保相同的 PyTorch 层和张量运算符可以与 XLAShardedTensor
堆叠在一起非常重要。这意味着用户无需为分片计算重写现有的运算符和模型代码。也就是说,XLAShardedTensor
将满足以下要求
XLAShardedTensor
是torch.Tensor
的子类,可直接与原生 torch 运算符和module.layers
一起使用。我们使用__torch_dispatch__
将XLAShardedTensor
发送到 XLA 后端。PyTorch/XLA 检索附加的分片注释以跟踪图并调用 XLA SPMDPartitioner。在内部,
XLAShardedTensor
(及其 global_tensor 输入)由XLATensor
支持,该数据结构持有对分片设备数据的引用。惰性执行后,当主机请求时(例如,打印全局张量的值),分片张量可以收集并物化回主机作为 global_tensor。
本地分片的句柄严格在惰性执行后实现。
XLAShardedTensor
公开 local_shards 以将可寻址设备上的本地分片作为List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]
返回。
目前还在努力将 XLAShardedTensor
集成到 DistributedTensor
API 中,以支持 XLA 后端 [RFC]。
DTensor 集成¶
PyTorch 在 2.1 中原型发布了 DTensor。我们正在将 PyTorch/XLA SPMD 集成到 DTensor API 中 RFC。我们有一个用于 distribute_tensor
的概念验证集成,它调用 mark_sharding
注释 API 使用 XLA 对张量及其计算进行分片
import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor
# distribute_tensor now works with `xla` backend using PyTorch/XLA SPMD.
mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
此功能尚处于实验阶段,请继续关注即将发布的版本中的更多更新、示例和教程。
分片感知主机到设备数据加载¶
PyTorch/XLA SPMD 采用单设备程序,分片并并行执行。SPMD 执行需要使用原生 PyTorch DataLoader,它将数据从主机同步传输到 XLA 设备。这会在每一步的输入数据传输过程中阻塞训练。为了提高原生数据加载性能,当传递可选的 kwarg _input_sharding_ 时,我们使 PyTorch/XLA ParallelLoader 直接支持输入分片(src)。
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# optional input_sharding field
input_sharding=xs.ShardingSpec(input_mesh, (0, 1, 2, 3)))
分布式检查点¶
PyTorch/XLA SPMD 通过专用的 Planner
实例与 torch.distributed.checkpoint 库兼容。用户可以通过此通用接口同步保存和加载检查点。
SPMDSavePlanner 和 SPMDLoadPlanner(src)类使 save
和 load
函数可以直接对 XLAShardedTensor
的分片进行操作,从而在 SPMD 训练中实现分布式检查点的所有优势。
以下是同步分布式检查点 API 的演示
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc
# Saving a state_dict
state_dict = {
"model": model.state_dict(),
"optim": optim.state_dict(),
}
dist_cp.save(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=xc.SPMDSavePlanner(),
)
...
# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
"model": model.state_dict(),
}
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])
CheckpointManager¶
实验性的 CheckpointManager 接口在 torch.distributed.checkpoint
函数之上提供了一个更高级别的 API,以实现一些关键功能
**托管检查点**:
CheckpointManager
获取的每个检查点都由获取该检查点的步骤标识。所有跟踪的步骤都可以通过CheckpointManager.all_steps
方法访问,并且可以使用CheckpointManager.restore
还原任何跟踪的步骤。**异步检查点**:通过
CheckpointManager.save_async
API 获取的检查点会异步写入持久存储,以在检查点期间解除对训练的阻塞。输入分片 state_dict 首先移动到 CPU,然后将检查点分派给后台线程。**抢占时的自动检查点**:在 Cloud TPU 上,可以检测到抢占并在进程终止之前获取检查点。要使用此功能,请确保您的 TPU 通过启用了 自动检查点 的 QueuedResource 进行配置,并确保在构建 CheckpointManager 时设置了
chkpt_on_preemption
参数(默认情况下启用此选项)。**FSSpec 支持**:
CheckpointManager
使用 fsspec 存储后端直接对任何与 fsspec 兼容的文件系统(包括 GCS)进行检查点。
以下是 CheckpointManager 的用法示例
from torch_xla.experimental.distributed_checkpoint import CheckpointManager
# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
# Choose the highest step
best_step = max(tracked_steps)
state_dict = {'model': model.state_dict()}
chkpt_mgr.restore(best_step, state_dict)
model.load_state_dict(state_dict['model'])
# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
...
state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
if chkpt_mgr.save_async(step, state_dict):
print(f'Checkpoint taken at step {step}')
进程组¶
要使用 torch.distributed
API(例如分布式检查点),需要一个进程组。在 SPMD 模式下,不支持 xla
后端,因为编译器负责所有集合操作。
相反,必须使用 CPU 进程组(例如 gloo
)。在 TPU 上,仍然支持 xla://
init_method 来发现主 IP、全局世界大小和主机排名。以下是一个初始化示例
import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr
xr.use_spmd()
# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')
虚拟设备优化¶
PyTorch/XLA 通常在定义张量后立即将张量数据从主机异步传输到设备。这是为了将数据传输与图跟踪时间重叠。但是,由于 GSPMD 允许用户在定义张量_之后_修改张量分片,因此我们需要一种优化来防止张量数据在主机和设备之间来回传输。我们引入了虚拟设备优化,这是一种将张量数据首先放置在虚拟设备 SPMD:0 上的技术,然后在所有分片决策完成后再上传到物理设备。SPMD 模式下的每个张量数据都放置在虚拟设备 SPMD:0 上。虚拟设备作为 XLA 设备 XLA:0 向用户公开,并在物理设备(如 TPU:0、TPU:1 等)上具有实际分片。
进程数¶
与现有的 DDP 和 FSDP 不同,在 SPMD 模式下,每个加速器主机上始终只运行一个进程。这提供了一个好处,即 PyTorch/XLA 只需要编译每个图一次,并且可以将其用于连接到此主机的所有加速器。
在 TPU Pod 上运行 SPMD¶
如果根据设备数量而不是硬编码常量构建网格和分区规范,则从单个 TPU 主机切换到 TPU Pod 不需要更改代码。要在 TPU Pod 上运行 PyTorch/XLA 工作负载,请参阅 PJRT 指南的Pods 部分。
在 GPU 上运行 SPMD¶
PyTorch/XLA 支持在 NVIDIA GPU(单节点或多节点)上运行 SPMD。训练/推理脚本与用于 TPU 的脚本相同,例如ResNet 脚本。为了使用 SPMD 执行脚本,我们利用torchrun
PJRT_DEVICE=CUDA \
torchrun \
--nnodes=${NUM_GPU_MACHINES} \
--node_rank=${RANK_OF_CURRENT_MACHINE} \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:<PORT>" \
training_or_inference_script_using_spmd.py
--nnodes
:要使用多少台 GPU 机器。--node_rank
:当前 GPU 机器的索引。该值可以是 0、1、…、${NUMBER_GPU_VM}-1。--nproc_per_node
:由于 SPMD 的要求,该值必须为 1。–rdzv_endpoint:节点等级为 0 的 GPU 机器的端点,格式为host:port`。主机将是内部 IP 地址。``port`可以是机器上任何可用的端口。对于单节点训练/推理,可以省略此参数。
例如,如果要使用 SPMD 在 2 台 GPU 机器上训练 ResNet 模型,则可以在第一台机器上运行以下脚本
XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" \
pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128
并在第二台机器上运行以下脚本
XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" \
pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128
有关更多信息,请参阅GPU 上的 SPMD 支持 RFC。
参考示例¶
使用 SPMD 表示数据并行¶
SPMD API 足够通用,可以表示数据并行和模型并行。可以通过注释输入批处理维度以进行分片来简单地实现数据并行。在这里,我们在所有可用设备(N 路)上对批处理维度进行了分片:有两种使用 SPMD 来表达数据并行或批处理分片的方法
num_devices = xr.global_runtime_device_count()
# Assume data is 4d and 0th dimension is the batch dimension
mesh_shape = (num_devices, 1, 1, 1)
input_mesh = xs.Mesh(device_ids, mesh_shape, ('B', 'C', 'W', 'H'))
partition_spec = range(num_devices)
# Shard the batch dimension
xs.mark_sharding(input_tensor, input_mesh, partition_spec)
PyTorch/XLA 的 MpDeviceLoader 支持输入批处理分片,它还在后台将批处理加载到设备
num_devices = xr.global_runtime_device_count()
# Assume data is 4d and 0th dimension is the batch dimension
mesh_shape = (num_devices, 1, 1, 1)
input_mesh = xs.Mesh(device_ids, mesh_shape, ('B', 'C', 'W', 'H'))
partition_spec = range(num_devices)
# Use MpDeviceLoader to load data in background
train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding=xs.ShardingSpec(input_mesh, partition_spec))
我们强烈建议使用第二种方法,因为它应该会产生更好的训练性能。
使用 SPMD 表达 FSDP(完全分片数据并行)¶
PyTorch 的 FSDP 是数据并行 + 在第 0 维分片的模型参数。用户首先需要使用 SPMD 来表达数据并行,如上一节所述。
for name, param in model.named_parameters():
shape = (num_devices,) + (1,) * (len(param.shape) - 1)
mesh = xs.Mesh(device_ids, shape)
xs.mark_sharding(param, mesh, range(len(param.shape)))
使用 SPMD 运行 Resnet50 示例¶
我们提供了一个resnet50的快速示例,其中包含几种不同的 SPMD 分片策略供您试用。您可以首先使用以下命令在不使用 SPMD 的情况下运行它
python test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 512
并检查吞吐量。之后,您可以使用以下命令启用批处理分片
XLA_USE_SPMD=1 python test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 2048 --model=resnet50 --sharding=batch
请注意,我使用的批处理大小是原来的 4 倍,因为我是在 TPU v4 上运行它,它连接了 4 个 TPU 设备。您应该会看到吞吐量大约是非 spmd 运行的 4 倍。
SPMD 调试工具¶
我们为 PyTorch/XLA SPMD 用户提供了一个分片放置可视化调试工具
,适用于 TPU/GPU/CPU 的单主机/多主机:您可以使用visualize_tensor_sharding
可视化分片张量,或者可以使用visualize_sharding
可视化共享字符串。以下是在 TPU 单主机 (v4-8) 上使用visualize_tensor_sharding
或visualize_sharding
的两个代码示例
使用
visualize_tensor_sharding
的代码片段和可视化结果
import rich
# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))
# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)

使用
visualize_sharding
的代码片段和可视化结果
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)

您可以在 TPU/GPU/CPU 单主机上使用这些示例,并对其进行修改以在多主机上运行。您还可以将其修改为分片样式tiled
、partial_replication
和replicated
。
自动分片¶
我们正在引入一项新的 PyTorch/XLA SPMD 功能,称为自动分片
,RFC。这是r2.3
和nightly
中的一个实验性功能,支持XLA:TPU
和单个 TPUVM 主机。
可以通过以下方式之一启用 PyTorch/XLA 自动分片
设置环境变量
XLA_SPMD_AUTO=1
在代码开头调用 SPMD API
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
使用
auto-policy
和xla
调用pytorch.distributed._tensor.distribute_module
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
# Currently, model should be loaded to xla device via distribute_module.
model = MyModule() # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)
(可选)可以设置以下选项/环境变量来控制基于 XLA 的自动分片传递的行为
XLA_AUTO_USE_GROUP_SHARDING
:对参数进行组重新分片。默认情况下设置。XLA_AUTO_SPMD_MESH
:用于自动分片的逻辑网格形状。例如,XLA_AUTO_SPMD_MESH=2,2
对应于一个 2x2 的网格,具有 4 个全局设备。如果未设置,将使用num_devices,1
的默认设备网格形状。
通过 SPMD 实现完全分片数据并行¶
通过 SPMD 或 FSDPv2 实现完全分片数据并行是一个实用程序,它在 SPMD 中重新表达了著名的 FSDP 算法。这是一个实验性功能,旨在为用户提供熟悉的界面,让他们享受 SPMD 带来的所有好处。设计文档在此。
在继续之前,请查看SPMD 用户指南。
使用示例
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
# Define the mesh following common SPMD practice
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
# Shard the input, and assume x is a 2D tensor.
x = xs.mark_sharding(x, mesh, ('fsdp', None))
# As normal FSDP, but an extra mesh is needed.
model = FSDPv2(my_module, mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
也可以单独对各个层进行分片,并使用外部包装器处理任何剩余的参数。自动包装功能将在未来版本中提供。
分片输出¶
为了确保 XLA 编译器正确实现 FSDP 算法,我们需要对权重和激活进行分片。这意味着对 forward 方法的输出进行分片。由于 forward 函数的输出可能会有所不同,因此我们提供 shard_output 来对您的模块输出不属于以下类别的情况下的激活进行分片
单个张量
一个张量元组,其中第 0 个元素是激活。
使用示例
def shard_output(output, mesh):
xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))
model = FSDPv2(my_module, mesh, shard_output)
梯度检查点¶
目前,需要在 FSDP 包装器之前将梯度检查点应用于模块。否则,递归循环到子模块最终将导致无限循环。我们将在未来版本中解决此问题。
使用示例
from torch_xla.distributed.fsdp import checkpoint_module
model = FSDPv2(checkpoint_module(my_module), mesh)