PyTorch 在 XLA 设备上¶
PyTorch 使用 torch_xla 包 在 XLA 设备(如 TPU)上运行。本文档介绍了如何在这些设备上运行模型。
创建 XLA 张量¶
PyTorch/XLA 为 PyTorch 添加了一个新的 xla
设备类型。此设备类型的工作方式与其他 PyTorch 设备类型相同。例如,以下是如何创建和打印 XLA 张量
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
这段代码应该很熟悉。PyTorch/XLA 使用与普通 PyTorch 相同的接口,并添加了一些功能。导入 torch_xla
初始化 PyTorch/XLA,而 xm.xla_device()
返回当前的 XLA 设备。这可能是 CPU 或 TPU,具体取决于您的环境。
XLA 张量是 PyTorch 张量¶
PyTorch 操作可以在 XLA 张量上执行,就像在 CPU 或 CUDA 张量上一样。
例如,可以将 XLA 张量加在一起
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
或进行矩阵乘法
print(t0.mm(t1))
或与神经网络模块一起使用
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
与其他设备类型一样,XLA 张量只能与同一设备上的其他 XLA 张量一起使用。因此,像这样的代码
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
将抛出错误,因为 torch.nn.Linear
模块位于 CPU 上。
在 XLA 设备上运行模型¶
构建新的 PyTorch 网络或将现有网络转换为在 XLA 设备上运行,只需要几行 XLA 特定的代码。以下代码片段突出显示了在单个设备和使用 XLA 多处理的多个设备上运行时的这些代码行。
在单个 XLA 设备上运行¶
以下代码片段显示了一个在单个 XLA 设备上训练的网络
import torch_xla.core.xla_model as xm
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
此代码片段突出显示了将模型切换到 XLA 上运行是多么容易。模型定义、数据加载器、优化器和训练循环可以在任何设备上运行。唯一 XLA 特定的代码是获取 XLA 设备并标记步骤的几行代码。在每次训练迭代结束时调用 xm.mark_step()
会导致 XLA 执行其当前图并更新模型的参数。有关 XLA 如何创建图和运行操作的更多信息,请参阅 XLA 张量深入解析。
使用多处理在多个 XLA 设备上运行¶
PyTorch/XLA 使得通过在多个 XLA 设备上运行来加速训练变得容易。以下代码片段展示了如何
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
def _mp_fn(index):
device = xm.xla_device()
mp_device_loader = pl.MpDeviceLoader(train_loader, device)
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
for data, target in mp_device_loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
xm.optimizer_step(optimizer)
if __name__ == '__main__':
xmp.spawn(_mp_fn, args=())
此多设备代码片段与之前的单设备代码片段之间存在三个差异。让我们逐一讨论。
xmp.spawn()
创建每个运行 XLA 设备的进程。
每个进程只能访问分配给当前进程的设备。例如,在 TPU v4-8 上,将有 4 个进程启动,每个进程将拥有一个 TPU 设备。
请注意,如果您在每个进程上打印
xm.xla_device()
,您将在所有设备上看到xla:0
。这是因为每个进程只能看到一个设备。这并不意味着多进程不起作用。唯一执行的是 TPU v2 和 TPU v3 上的 PJRT 运行时,因为将有#devices/2
个进程,每个进程将有 2 个线程(有关更多详细信息,请查看此 文档)。
MpDeviceLoader
将训练数据加载到每个设备上。
MpDeviceLoader
可以包装在 torch 数据加载器上。它可以将数据预加载到设备上,并将数据加载与设备执行重叠,以提高性能。MpDeviceLoader
还会在每个batches_per_execution
(默认为 1)批次被生成时为您调用xm.mark_step
。
xm.optimizer_step(optimizer)
整合设备之间的梯度,并发出 XLA 设备步长计算。
它基本上是一个
all_reduce_gradients
+optimizer.step()
+mark_step
,并返回被减少的损失。
模型定义、优化器定义和训练循环保持不变。
注意:重要的是要注意,在使用多处理时,用户只能从
xmp.spawn()
的目标函数(或调用堆栈中以xmp.spawn()
为父级的任何函数)中开始检索和访问 XLA 设备。
有关在多个 XLA 设备上使用多处理训练网络的更多信息,请参阅 完整的 multiprocessing 示例。
在 TPU Pod 上运行¶
不同加速器的多主机设置可能会有很大差异。本文档将讨论多主机训练的设备无关部分,并将使用 TPU + PJRT 运行时(目前在 1.13 和 2.x 版本中可用)作为示例。
在开始之前,请查看我们的用户指南 这里,它将解释一些 Google Cloud 基础知识,例如如何使用 gcloud
命令以及如何设置您的项目。您还可以查看 这里,了解所有 Cloud TPU 操作指南。本文档将重点介绍 PyTorch/XLA 设置的视角。
假设您在 train_mnist_xla.py
中拥有上述 mnist 示例。如果它是单主机多设备训练,您将 ssh 到 TPUVM 并运行类似的命令
PJRT_DEVICE=TPU python3 train_mnist_xla.py
现在,为了在 TPU v4-16(有两个主机,每个主机有 4 个 TPU 设备)上运行相同的模型,您需要
确保每个主机都可以访问训练脚本和训练数据。这通常通过使用
gcloud scp
命令或gcloud ssh
命令将训练脚本复制到所有主机来完成。在所有主机上同时运行相同的训练命令。
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
上面的 gcloud ssh
命令将 ssh 到 TPUVM Pod 中的所有主机,并在同一时间运行相同的命令。
注意:您需要在 TPUVM 虚拟机之外运行上面的
gcloud
命令。
模型代码和训练脚本对于多进程训练和多主机训练是相同的。PyTorch/XLA 和底层基础设施将确保每个设备都了解全局拓扑以及每个设备的本地和全局序号。跨设备通信将在所有设备之间进行,而不是在本地设备之间进行。
有关 PJRT 运行时以及如何在 pod 上运行它的更多详细信息,请参阅此 文档。有关 PyTorch/XLA 和 TPU pod 的更多信息以及在 TPU pod 上运行 resnet50(使用伪数据)的完整指南,请参阅此 指南。
XLA 张量深入解析¶
使用 XLA 张量和设备只需要更改几行代码。但即使 XLA 张量与 CPU 和 CUDA 张量非常相似,它们的内部结构却有所不同。本节将描述 XLA 张量的独特之处。
XLA 张量是惰性的¶
CPU 和 CUDA 张量会立即或急切地执行操作。另一方面,XLA 张量是惰性的。它们将操作记录在一个图中,直到需要结果。这种延迟执行允许 XLA 对其进行优化。例如,多个独立操作的图可能会被融合成一个优化的操作。
惰性执行通常对调用者不可见。PyTorch/XLA 会自动构建图,将它们发送到 XLA 设备,并在 CPU 和 XLA 设备之间复制数据时进行同步。在执行优化器步骤时插入一个屏障会显式地同步 CPU 和 XLA 设备。有关我们惰性张量设计的更多信息,您可以阅读这篇论文。
XLA 张量和 bFloat16¶
PyTorch/XLA 在 TPU 上运行时可以使用bfloat16 数据类型。事实上,PyTorch/XLA 在 TPU 上对浮点类型(torch.float
和 torch.double
)的处理方式不同。此行为由 XLA_USE_BF16
和 XLA_DOWNCAST_BF16
环境变量控制。
默认情况下,
torch.float
和torch.double
在 TPU 上都是torch.float
。如果设置了
XLA_USE_BF16
,那么torch.float
和torch.double
在 TPU 上都是bfloat16
。如果设置了
XLA_DOWNCAST_BF16
,那么torch.float
在 TPU 上是bfloat16
,而torch.double
在 TPU 上是float32
。如果 PyTorch 张量的数据类型为
torch.bfloat16
,它将直接映射到 TPU 的bfloat16
(XLABF16
原语类型)。
开发者需要注意的是,TPU 上的 XLA 张量将始终报告其 PyTorch 数据类型,无论它们实际使用的是什么数据类型。此转换是自动且不透明的。如果 TPU 上的 XLA 张量被移回 CPU,它将从其实际数据类型转换为其 PyTorch 数据类型。根据代码的操作方式,这种由处理单元类型触发的转换可能很重要。
内存布局¶
XLA 张量的内部数据表示对用户来说是不透明的。它们不公开其存储,并且始终看起来是连续的,这与 CPU 和 CUDA 张量不同。这允许 XLA 调整张量的内存布局以获得更好的性能。
将 XLA 张量移入和移出 CPU¶
XLA 张量可以从 CPU 移到 XLA 设备,也可以从 XLA 设备移到 CPU。如果视图被移动,那么它所查看的数据也会被复制到另一个设备,并且视图关系不会被保留。换句话说,一旦数据被复制到另一个设备,它就与之前的设备或其上的任何张量没有关系。同样,根据代码的操作方式,理解和适应这种转换可能很重要。
保存和加载 XLA 张量¶
XLA 张量应该在保存之前移到 CPU,如下面的代码片段所示
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
这使您可以将加载的张量放在任何可用的设备上,而不仅仅是它们初始化所在的设备。
根据上面关于将 XLA 张量移到 CPU 的说明,在处理视图时必须小心。建议您在张量被加载并移到目标设备后重新创建视图,而不是保存视图。
提供了一个实用程序 API 来保存数据,它负责先将数据移到 CPU
import torch
import torch_xla
import torch_xla.core.xla_model as xm
xm.save(model.state_dict(), path)
在多个设备的情况下,上述 API 将只保存主设备序号(0)的数据。
如果内存大小有限,而模型参数大小较大,则提供了一个 API 来减少主机上的内存占用。
import torch_xla.utils.serialization as xser
xser.save(model.state_dict(), path)
此 API 将 XLA 张量一次一个地流式传输到 CPU,从而减少了使用的主机内存量,但它需要一个匹配的加载 API 来恢复。
import torch_xla.utils.serialization as xser
state_dict = xser.load(path)
model.load_state_dict(state_dict)
直接保存 XLA 张量是可能的,但不建议这样做。XLA 张量始终加载回它们保存的设备,如果该设备不可用,则加载将失败。PyTorch/XLA 与 PyTorch 的所有部分一样,正在积极开发中,此行为将来可能会发生变化。
进一步阅读¶
其他文档可在 PyTorch/XLA 仓库 中找到。更多在 TPU 上运行网络的示例可在 此处 获取。
PyTorch/XLA API¶
xla_model¶
-
torch_xla.core.xla_model.
xla_device
(n=None, devkind=None)[source]¶ 返回 XLA 设备的给定实例。
- 参数
n (python:int, 可选) – 要返回的特定实例(序数)。如果指定,将返回特定的 XLA 设备实例。否则将返回第一个 devkind 设备。
devkind (string..., 可选) – 如果指定,则为 TPU、GPU 或 CPU 之一。
- 返回值
具有请求实例的 torch.device。
-
torch_xla.core.xla_model.
get_xla_supported_devices
(devkind=None, max_devices=None)[source]¶ 返回给定类型的支持设备列表。
- 参数
devkind (string..., optional) – 如果指定,则为 TPU、GPU 或 CPU 之一(当前未实现“GPU” XLA 设备)。
max_devices (python:int, optional) – 要返回的该类型设备的最大数量。
- 返回值
设备字符串列表。
-
torch_xla.core.xla_model.
xla_device_hw
(device)[source]¶ 返回给定设备的硬件类型。
- 参数
device (string or torch.device) – 将映射到实际设备的 xla 设备。
- 返回值
给定设备的硬件类型 (CPU、TPU、GPU) 的字符串表示形式。
-
torch_xla.core.xla_model.
get_ordinal
(defval=0)[source]¶ 检索当前线程的复制序号。
序号范围从 0 到 xrt_world_size() 减 1。
- 参数
defval (python:int, optional) – 如果没有可用的复制信息,则返回的默认值。对于 PjRt 忽略。默认值:0
- 返回值
当前线程的复制序号。
-
torch_xla.core.xla_model.
get_local_ordinal
(defval=0)[source]¶ 检索当前线程的复制本地序号。
本地序号范围从 0 到本地设备数量减 1。
- 参数
defval (python:int, optional) – 如果没有可用的复制信息,则返回的默认值。对于 PjRt 忽略。默认值:0
- 返回值
当前线程的复制本地序数。
-
torch_xla.core.xla_model.
is_master_ordinal
(local=True)[source]¶ 检查当前进程是否为主序数 (0)。
- 参数
local (bool) – 是否应检查本地或全局主序数。在多主机复制的情况下,只有一个全局主序数(主机 0,设备 0),而有 NUM_HOSTS 个本地主序数。默认值:True
- 返回值
一个布尔值,指示当前进程是否为主序数。
-
torch_xla.core.xla_model.
xrt_world_size
(defval=1)[source]¶ 检索参与复制的设备数量。
- 参数
defval (python:int, optional) – 如果没有可用的复制信息,则返回的默认值。默认值:1
- 返回值
参与复制的设备数量。
-
torch_xla.core.xla_model.
all_reduce
(reduce_type, inputs, scale=1.0, groups=None, cctx=None, pin_layout=True)[source]¶ 对输入张量执行就地归约操作。
- 参数
reduce_type (string) –
xm.REDUCE_SUM
、xm.REDUCE_MUL
、xm.REDUCE_AND
、xm.REDUCE_OR
、xm.REDUCE_MIN
和xm.REDUCE_MAX
之一。inputs – 要执行所有归约操作的单个 torch.Tensor 或 torch.Tensor 列表。
scale (python:float) – 归约后要应用的默认缩放值。默认值:1.0
groups (list, optional) –
一个列表列表,表示 all_reduce() 操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则将只有一个包含所有副本的组。
pin_layout (bool, optional) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时出现潜在的数据损坏,但它可能会导致某些 xla 编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。
- 返回值
如果传递单个 torch.Tensor,则返回值为一个 torch.Tensor,其中包含减少的值(跨所有副本)。如果传递列表/元组,则此函数对输入张量执行就地全减少操作,并返回列表/元组本身。
-
torch_xla.core.xla_model.
all_gather
(value, dim=0, groups=None, output=None, pin_layout=True)[source]¶ 沿给定维度执行全收集操作。
- 参数
value (torch.Tensor) – 输入张量。
dim (python:int) – 收集维度。默认值:0
groups (list, optional) –
一个列表列表,表示 all_gather() 操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则将只有一个包含所有副本的组。
output (torch.Tensor) – 可选输出张量。
pin_layout (bool, optional) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时出现潜在的数据损坏,但它可能会导致某些 xla 编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。
- 返回值
一个张量,在
dim
维度上包含来自所有参与副本的值。
-
torch_xla.core.xla_model.
all_to_all
(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)[source]¶ 对输入张量执行 XLA AllToAll() 操作。
参见:https://tensorflowcn.cn/xla/operation_semantics#alltoall
- 参数
value (torch.Tensor) – 输入张量。
split_dimension (python:int) – 应该进行拆分的维度。
concat_dimension (python:int) – 应该进行连接的维度。
split_count (python:int) – 拆分计数。
groups (list, optional) –
一个列表列表,表示 all_reduce() 操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则将只有一个包含所有副本的组。
pin_layout (bool, optional) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时出现潜在的数据损坏,但它可能会导致某些 xla 编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。
- 返回值
all_to_all() 操作的结果 torch.Tensor。
-
torch_xla.core.xla_model.
add_step_closure
(closure, args=(), run_async=False)[source]¶ 将闭包添加到要在步骤结束时运行的闭包列表中。
在模型训练过程中,经常需要打印/报告(打印到控制台、发布到 TensorBoard 等)信息,这些信息需要检查中间张量的内容。在模型代码的不同位置检查不同张量的内容需要多次执行,通常会导致性能问题。添加一个步骤闭包将确保它在屏障之后运行,此时所有实时张量都将已物化为设备数据。实时张量将包括闭包参数捕获的那些张量。因此,使用 add_step_closure() 将确保即使在排队多个闭包(需要检查多个张量)的情况下,也只会执行一次。步骤闭包将按它们排队的顺序依次运行。请注意,即使使用此 API,执行也会得到优化,但建议每 N 步节流一次打印/报告事件。
- 参数
closure (callable) – 要调用的函数。
args (tuple) – 要传递给闭包的参数。
run_async – 如果为 True,则异步运行闭包。
-
torch_xla.core.xla_model.
wait_device_ops
(devices=[])[source]¶ 等待给定设备上的所有异步操作完成。
- 参数
devices (string..., optional) – 需要等待其异步操作的设备。如果为空,则将等待所有本地设备。
-
torch_xla.core.xla_model.
optimizer_step
(optimizer, barrier=False, optimizer_args={}, groups=None, pin_layout=True)[source]¶ 运行提供的优化器步骤并发出 XLA 设备步骤计算。
- 参数
optimizer (
torch.Optimizer
) – 需要调用其 step() 函数的 torch.Optimizer 实例。将使用 optimizer_args 命名参数调用 step() 函数。barrier (布尔值, 可选) – 是否在该 API 中发出 XLA 张量屏障。如果使用 PyTorch XLA 的 ParallelLoader 或 DataParallel 支持,则不需要此操作,因为屏障将由 XLA 数据加载器迭代器 next() 调用发出。默认值:False
optimizer_args (字典, 可选) – optimizer.step() 调用的命名参数字典。
groups (list, optional) –
一个列表列表,表示 all_reduce() 操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则将只有一个包含所有副本的组。
pin_layout (布尔值, 可选) – 在减少梯度时是否固定布局。有关详细信息,请参阅 xm.all_reduce。
- 返回值
与 optimizer.step() 调用返回的值相同。
-
torch_xla.core.xla_model.
save
(data, file_or_path, master_only=True, global_master=False)[源代码]¶ 将输入数据保存到文件中。
保存的数据将在保存之前传输到 PyTorch CPU 设备,因此后续的 torch.load() 将加载 CPU 数据。在处理视图时必须小心。建议您在将张量加载到目标设备后重新创建视图,而不是保存视图。
- 参数
data – 要保存的输入数据。任何嵌套的 Python 对象组合(列表、元组、集合、字典等)。
file_or_path – 数据保存操作的目标。可以是文件路径或 Python 文件对象。如果 master_only 为
False
,则路径或文件对象必须指向不同的目标,否则来自同一主机的所有写入将相互覆盖。master_only (布尔值, 可选) – 是否只有主设备应该保存数据。如果为 False,则 file_or_path 参数对于参与复制的每个序号应该是一个不同的文件或路径,否则同一主机上的所有副本将写入同一个位置。默认值:True
global_master (布尔值, 可选) – 当
master_only
为True
时,此标志控制每个主机的 master(如果global_master
为False
)是否保存内容,或者只有全局 master(序号 0)保存内容。默认值:False
-
torch_xla.core.xla_model.
rendezvous
(tag, payload=b'', replicas=[])[source]¶ 等待所有网格客户端到达指定的 rendezvous 点。
- 参数
tag (string) – 要加入的 rendezvous 点的名称。
payload (bytes, optional) – 要发送到 rendezvous 点的有效负载。
replicas (list, python:int) – 参与 rendezvous 的副本序号。空表示网格中的所有副本。默认值:[]
- 返回值
所有其他核心交换的有效负载,其中核心序号为 i 的有效负载位于返回元组中的位置 i。
-
torch_xla.core.xla_model.
do_on_ordinals
(target, data=(), ordinals=(0, ))[source]¶ 仅在给定的序号集上运行函数。
- 参数
target (callable) – 要在 ordinals 上运行的函数。
data – target 函数的任何输入数据,其中包含张量。target 函数使用的所有 XLA 张量都必须在此参数中传递。函数使用的所有其他数据都可以像往常一样被 Python 解释器捕获。默认值:()
ordinals (list, python:int) – target 函数应运行的序号列表/集。默认值: (0,)
- 返回值
在运行 target 函数的序号中,函数返回值,否则为 None。
-
torch_xla.core.xla_model.
mesh_reduce
(tag, data, reduce_fn)[source]¶ 执行一个非图客户端网格约简。
- 参数
tag (string) – 要加入的 rendezvous 点的名称。
data – 要约简的数据。 reduce_fn 可调用对象将接收一个列表,其中包含来自所有网格客户端进程(每个核心一个)的相同数据的副本。
reduce_fn (可调用对象) – 一个接收 data 类对象列表并返回缩减结果的函数。
- 返回值
缩减后的值。
-
torch_xla.core.xla_model.
set_rng_state
(seed, device=None)[源代码]¶ 设置随机数生成器状态。
- 参数
seed (python:integer) – 要设置的状态。
device (字符串, 可选) – 需要设置 RNG 状态的设备。如果缺失,将设置默认设备种子。
-
torch_xla.core.xla_model.
get_rng_state
(device=None)[源代码]¶ 获取当前运行的随机数生成器状态。
- 参数
device (字符串, 可选) – 需要检索其 RNG 状态的设备。如果缺失,将设置默认设备种子。
- 返回值
RNG 状态,以整数形式表示。
-
torch_xla.core.xla_model.
get_memory_info
(device)[源代码]¶ 检索设备内存信息。
- 参数
device (字符串) – 请求其内存信息的设备。
- 返回值
一个包含 kb_free(以 KB 为单位的可用内存)和 kb_total(以 KB 为单位的总内存)键的字典。
-
torch_xla.core.functions.
all_reduce
(reduce_type, value, scale=1.0, groups=None)[源代码]¶ 对输入张量执行就地缩减操作。
这与 xm.all_reduce() 相同,但支持自动微分。
- 参数
reduce_type (字符串) –
REDUCE_SUM
、REDUCE_MUL
、REDUCE_AND
、REDUCE_OR
、REDUCE_MIN
和REDUCE_MAX
之一。value (torch.Tensor) – 要执行所有减少操作的值。
scale (python:float) – 归约后要应用的默认缩放值。默认值:1.0
groups (list, optional) –
一个列表列表,表示 all_reduce() 操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]
定义两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则将只有一个包含所有副本的组。
- 返回值
在选定副本中减少的值。
-
torch_xla.core.functions.
all_gather
(value, dim=0)[source]¶ 沿给定维度执行全收集操作。
这与 xm.all_gather() 相同,但支持自动梯度微分。
- 参数
value (torch.Tensor) – 输入张量。
dim (python:int) – 收集维度。默认值:0
- 返回值
一个张量,在
dim
维度上包含来自所有参与副本的值。
-
torch_xla.core.functions.
nms
(boxes, scores, score_threshold, iou_threshold, output_size)[source]¶ 执行非最大抑制操作。
- 参数
boxes (torch.Tensor) – 形状为 [N, 4] 的 torch.Tensor,列出以 (y0, x0, y1, x1) 形式表示的框坐标。
scores (torch.Tensor) – 形状为 [N] 的 torch.Tensor,列出每个框的分数。
score_threshold (torch.Tensor) – 框要被视为有效的最小分数。
iou_threshold (torch.Tensor) – 触发重叠逻辑的最小 IOU(交并比)分数。
output_size (python:int) – 返回索引的最大数量(必须小于或等于 N)。
- 返回值
一个 torch.Tensor 元组,第一个元素是选定的框索引,第二个元素是有效框的数量。
分布式¶
-
class
torch_xla.distributed.parallel_loader.
ParallelLoader
(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4, host_to_device_transfer_threads=1)[source]¶ 使用后台数据上传包装现有的 PyTorch DataLoader。
- 参数
loader (
torch.utils.data.DataLoader
) – 要包装的 PyTorch DataLoader。devices (torch.device…) – 数据必须发送到的设备列表。由 loader 返回的第 i 个样本将被发送到 devices[i % len(devices)]。
batchdim (python:int, optional) – 保存批次大小的维度。默认值:0
loader_prefetch_size (python:int, optional) – 用于从 loader 读取样本的线程使用的队列的最大容量,由将数据上传到设备的 worker 线程处理。默认值:8
device_prefetch_size (python:int, optional) – 每个设备队列的最大大小,worker 线程将已发送到设备的张量存放在这里。默认值:4
host_to_device_transfer_threads (python:int, optional) – 并行工作以将数据从加载器队列传输到设备队列的线程数。默认值:1
-
torch_xla.distributed.xla_multiprocessing.
spawn
(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]¶ 启用基于多处理的复制。
- 参数
fn (可调用对象) – 用于每个参与复制的设备的函数。该函数将被调用,第一个参数是复制中进程的全局索引,后面是传递给 args 的参数。
args (元组) – fn 的参数。默认值:空元组
nprocs (python:int) – 复制的进程/设备数量。目前,如果指定,可以是 1 或最大设备数量。
join (布尔值) – 是否应该阻塞调用以等待已生成的进程完成。默认值:True
daemon (布尔值) – 是否应该将已生成的进程的 daemon 标志设置为 True(参见 Python 多进程 API)。默认值:False
start_method (字符串) – Python multiprocessing 进程创建方法。默认值:spawn
- 返回值
由 torch.multiprocessing.spawn API 返回的相同对象。如果 nprocs 为 1,则将直接调用 fn 函数,并且 API 将返回 None。
-
类
torch_xla.distributed.xla_multiprocessing.
MpModelWrapper
(model)[源代码]¶ 包装模型以在使用 fork 方法时最大限度地减少主机内存使用量。
此类应与 spawn(…, start_method=’fork’) API 一起使用,以最大限度地减少主机内存的使用。它不是在每个多进程进程上创建模型,从而复制模型的初始主机内存,而是只在全局范围内创建一次模型,然后在 spawn() 目标函数中将其移动到每个设备中。示例
WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork()) def _mp_fn(index, ...): device = xm.xla_device() model = WRAPPED_MODEL.to(device) ... xmp.spawn(_mp_fn, ..., start_method='fork')
此方法有两个优点。首先,它只使用一个内存页副本来托管原始模型权重,其次,它通过降低进程期间系统内存的负载来序列化将包装的模型移动到每个设备的过程。
-
类
torch_xla.distributed.xla_multiprocessing.
MpSerialExecutor
[源代码]¶ 实用程序,用于在多核进程中以串行方式运行函数。
示例
# At global scope. SERIAL_EXEC = xmp.MpSerialExecutor() def load_dataset(path): return maybe_download_and_load(path) def _mp_fn(index, ...): # Avoid all cores downloading the same data with the serial executor. dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data')) ... xmp.spawn(_mp_fn, ...)
utils¶
-
类
torch_xla.utils.tf_record_reader.
TfRecordReader
(path, compression='', buffer_size=16777216, transforms=None)[源代码]¶ 读取 TfRecords 或 TfExamples。
- 参数
path (字符串) – 包含 TfRecords 的文件的路径。
compression (字符串, 可选) – 压缩类型。空字符串表示不压缩,否则为
ZLIB
或GZIP
。默认:不压缩。buffer_size (python:int, 可选) – 用于读取 TfRecords 的缓冲区大小。默认:16 * 1024 * 1024
transforms (dict, optional) – 一个字典,其键与 TfExample 标签名称匹配,其值为可调用对象(将被调用以转换匹配的张量数据),或
STR
用于字符串转换。
-
class
torch_xla.utils.utils.
SampleGenerator
(data, sample_count)[source]¶ 迭代器,返回给定输入数据的多个样本。
可以替代 PyTorch DataLoader 来生成合成数据。
- 参数
data – 每次迭代器步骤应返回的数据。
sample_count – 要返回的 data 样本的最大数量。
-
torch_xla.utils.serialization.
save
(data, path, master_only=True, global_master=False)[source]¶ 将输入数据保存到文件中。
保存的数据将在保存之前传输到 PyTorch CPU 设备,因此后续的 torch.load() 将加载 CPU 数据。在处理视图时必须小心。建议您在将张量加载到目标设备后重新创建视图,而不是保存视图。
- 参数
data – 要保存的输入数据。任何嵌套的 Python 对象组合(列表、元组、集合、字典等)。
path – 数据保存操作的目标文件。如果 master_only 为
False
,则路径必须指向不同的目标,否则来自同一主机的所有写入将相互覆盖。master_only (bool, optional) – 是否只有主设备应该保存数据。如果为 False,则 path 参数应为参与复制的每个序号的不同路径,否则同一主机上的所有副本将写入同一位置。默认值:True
global_master (布尔值, 可选) – 当
master_only
为True
时,此标志控制每个主机的 master(如果global_master
为False
)是否保存内容,或者只有全局 master(序号 0)保存内容。默认值:False
-
torch_xla.utils.serialization.
load
(path)[source]¶ 加载之前使用 save() API 保存的数据。
- 参数
path (str) – 传递给 save() API 的路径。
- 返回值
加载的数据。
-
torch_xla.utils.gcsfs.
open
(path, mode='r', encoding=None)[source]¶ 打开 Google Cloud Storage (GCS) 文件以供读取或写入。
- 参数
path (string) – 文件的 GCS 路径。必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 存储桶的名称,PATH
是一个 / 分隔的路径。mode (string, optional) – 打开模式,类似于
open()
API。默认值:‘r’encoding (string, optional) – 在文本模式下打开时用于将字节解码为字符串的字符编码。默认值:None
- 返回值
GCS 文件对象。
-
torch_xla.utils.gcsfs.
list
(path)[source]¶ 列出 GCS 存储桶的内容。
- 参数
path (string) – 文件的 GCS 路径。必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 存储桶的名称,PATH
是一个 / 分隔的路径。- 返回值
一个
GcsBlob
对象列表。
-
torch_xla.utils.gcsfs.
stat
(path)[source]¶ 获取 GCS 文件的信息。
- 参数
path (string) – 文件的 GCS 路径。必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 存储桶的名称,PATH
是一个 / 分隔的路径。- 返回值
一个
GcsBlob
对象。
-
torch_xla.utils.gcsfs.
remove
(path)[source]¶ 删除 GCS 存储桶中的一个 Blob。
- 参数
path (string) – 文件的 GCS 路径。必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 存储桶的名称,PATH
是一个 / 分隔的路径。
-
torch_xla.utils.gcsfs.
rmtree
(path)[source]¶ 删除给定路径下的所有 GCS Blob。
- 参数
path (string) –
文件模式或文件夹的 GCS 路径。必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 的名称存储桶,而
PATH
是一个用 / 分隔的路径。
-
torch_xla.utils.gcsfs.
read
(path)[source]¶ 读取 GCS Blob 的全部内容。
- 参数
path (string) – 文件的 GCS 路径。必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 存储桶的名称,PATH
是一个 / 分隔的路径。- 返回值
存储在 GCS Blob 中的字节。
-
torch_xla.utils.gcsfs.
write
(path, content)[source]¶ 将字符串/字节或文件写入 GCS Blob。
- 参数
path (string) – 文件的 GCS 路径。必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 存储桶的名称,PATH
是一个 / 分隔的路径。content (string, bytes or file object) – 要写入
path
的内容。
-
torch_xla.utils.gcsfs.
generic_open
(path, mode='r', encoding=None)[source]¶ 打开文件(GCS 或其他)以进行读取或写入。
- 参数
path (string) –
要打开的文件的路径。如果为 GCS 路径,则必须为“gs://BUCKET_NAME/PATH”,其中
BUCKET_NAME
是 GCS 的名称存储桶,而
PATH
是一个用 / 分隔的路径。mode (string, optional) – 打开模式,类似于
open()
API。默认值:‘r’encoding (string, optional) – 在文本模式下打开时用于将字节解码为字符串的字符编码。默认值:None
- 返回值
打开的文件对象。
-
torch_xla.utils.gcsfs.
generic_read
(path)[source]¶ 读取提供的路径的全部内容。
- 参数
path (string) – 要读取的 GCS 路径或本地路径。
- 返回值
存储在 GCS Blob 或本地文件中的字节。
-
torch_xla.utils.gcsfs.
generic_write
(output_string, path, makedirs=False)[source]¶ 将字符串/字节或文件写入 GCS Blob 或本地磁盘。
根据传入的路径,此 API 可以写入本地或 GCS 文件。检查 path 是否以 'gs://' 前缀开头,否则使用 open。
- 参数
output_string (string) – 要写入输出的字符串。
path (string) – 输出的 GCS 路径或本地路径。
makedirs (bool) – 如果 path 父文件夹不存在,是否创建。默认值:False
-
torch_xla.utils.gcsfs.
is_gcs_path
(path)[source]¶ 检查路径是否为 GCS 路径。
- 参数
path (string) – 要检查的路径。
- 返回值
path 是否为 GCS 路径。
-
class
torch_xla.utils.cached_dataset.
CachedDataset
(data_set, path, max_files_per_folder=1000, compress=True)[source]¶ 通过提供文件缓存来包装现有的 torch.utils.data.Dataset。
CachedDataset 可用于用存储/网络资源来交换处理原始数据集所需的 CPU/RAM 资源。示例
train_dataset = datasets.MNIST( FLAGS.datadir, train=True, download=True, transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) train_dataset = CachedDataset(train_dataset, FLAGS.dscache_dir)
CachedDataset 将透明地缓存原始 Dataset 样本,因此第一次运行后,每次运行都不会再触发与原始样本处理相关的任何 CPU/RAM 使用。一旦 CachedDataset 被完全缓存,它就可以被导出(即,tar.gz)并在不同的机器上使用。只需解压缩 tar.gz 并将 None 作为原始 Dataset 传递:示例
train_dataset = CachedDataset(None, FLAGS.dscache_dir)
要完全缓存 CachedDataset,只需运行 warmup() API。保存在 GCS 上的 CachedDataset 具有从不同机器使用而无需显式导出的优势。
- 参数
data_set (torch.utils.data.Dataset) – 要缓存的原始 torch.utils.data.Dataset。如果所有输入样本都存储在 path 文件夹中,则可以将其设置为 None。
path (字符串) – 数据集样本应存储/加载的路径。除非所有样本都已存储,否则path 必须可写。 path 可以是 GCS 路径(以 gs:// 为前缀)。
max_files_per_folder (python:int) – 存储在单个文件夹中的最大文件数量。如果 data_set 为 None,则忽略此值并从缓存的元数据中获取。默认值:1000
compress (布尔值) – 是否应压缩保存的样本。压缩节省空间,但会增加压缩/解压缩所需的 CPU 资源。如果 data_set 为 None,则忽略此值并从缓存的元数据中获取。默认值:True
测试¶
故障排除¶
请注意,本节中的信息可能会在 PyTorch/XLA 软件的未来版本中被删除,因为其中许多信息是特定于可能发生变化的给定内部实现的。
为了诊断问题,我们可以使用 PyTorch/XLA 提供的执行指标和计数器。当模型速度缓慢时,首先要检查的是生成指标报告。
指标报告在诊断问题方面非常有用。如果您有指标报告,请尝试将其包含在您发送给我们的错误报告中。
执行自动指标分析¶
我们提供自动分析指标报告并提供摘要的方法。只需使用 PT_XLA_DEBUG=1
运行您的工作负载。一些示例输出将是
pt-xla-profiler: CompileTime too frequent: 21 counts during 11 steps
pt-xla-profiler: TransferFromServerTime too frequent: 11 counts during 11 steps
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
pt-xla-profiler: CompileTime too frequent: 23 counts during 12 steps
pt-xla-profiler: TransferFromServerTime too frequent: 12 counts during 12 steps
以下部分将解释如何获取和理解更详细的指标报告。
获取指标报告¶
在您的程序中添加以下行以生成报告
import torch_xla.debug.metrics as met
# For short report that only contains a few key metrics.
print(met.short_metrics_report())
# For full report that includes all metrics.
print(met.short_metrics_report())
理解指标报告¶
报告包括以下内容:
我们发出 XLA 编译的次数以及花费在发出编译上的时间。
我们执行的次数以及花费在执行上的时间
我们创建/销毁的设备数据句柄数量等。
此信息以样本百分位数的形式报告。例如
Metric: CompileTime
TotalSamples: 202
Counter: 06m09s401ms746.001us
ValueRate: 778ms572.062us / second
Rate: 0.425201 / second
Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
我们还提供计数器,它们是跟踪内部软件状态的命名整型变量。例如
Counter: CachedSyncTensors
Value: 395
在此报告中,任何以 aten::
开头的计数器都表示 XLA 设备和 CPU 之间的上下文切换,这可能是模型代码中潜在的性能优化领域。
计数器有助于了解哪些操作被路由回 PyTorch 的 CPU 引擎。它们使用其 C++ 命名空间进行完全限定
Counter: aten::nonzero
Value: 33
如果您看到除 nonzero
和 _local_scalar_dense
之外的 aten::
操作,通常意味着 PyTorch/XLA 中缺少降低。请随时在 GitHub 问题 上提交功能请求。
已知性能注意事项¶
PyTorch/XLA 在语义上与常规 PyTorch 相似,XLA 张量与 CPU 和 GPU 张量共享完整的张量接口。但是,XLA/硬件中的约束以及延迟评估模型表明某些模式可能会导致性能下降。
如果您的模型性能不佳,请牢记以下注意事项
XLA/TPU 在进行太多重新编译时会降低性能。
XLA 编译很昂贵。PyTorch/XLA 每次遇到新形状时都会自动重新编译图。通常,模型应该在几个步骤内稳定下来,并且您可以在训练的剩余时间内看到巨大的加速。
为了避免重新编译,不仅形状必须保持不变,而且所有主机中跨 XLA 设备的计算也必须保持不变。
可能的来源:
直接或间接使用
nonzero
会引入动态形状;例如,掩码索引base[index]
,其中index
是一个掩码张量。循环在步骤之间具有不同的迭代次数会导致不同的执行图,因此需要重新编译。
解决方案:
张量形状在迭代之间应该相同,或者应该使用少量形状变化。
尽可能将张量填充到固定大小。
某些操作没有对 XLA 的本机转换。
对于这些操作,PyTorch/XLA 会自动转移到 CPU 内存,在 CPU 上进行评估,并将结果转移回 XLA 设备。在训练步骤中进行太多此类操作会导致显着减慢速度。
可能的来源:
item()
操作明确要求评估结果。除非必要,否则不要使用它。
解决方案:
对于大多数操作,我们可以将它们降低到 XLA 以修复它。查看 指标报告部分 以了解缺少的操作,并在 GitHub 上打开功能请求。
即使 PyTorch 张量被识别为标量,也避免使用
tensor.item()
。将其保留为张量,并对其使用张量操作。在适用情况下,使用
torch.where
来替代控制流。例如,在 clip_grad*norm* 中使用item()
的控制流存在问题,会影响性能,因此我们通过调用torch.where
来 修复clip_grad_norm_
,这为我们带来了显著的性能提升。.. code-block:: python… else
device = parameters[0].device total_norm = torch.zeros([], device=device if parameters else None) for p in parameters
param_norm = p.grad.data.norm(norm_type) ** norm_type total_norm.add_(param_norm)
total_norm = (total_norm ** (1. / norm_type))
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6) for p in parameters
p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
``torch_xla.distributed.data_parallel`` 中的迭代器可能会丢弃输入迭代器中的最后几个批次。
这样做是为了确保我们在所有 XLA 设备上执行相同的工作量。
解决方案:
当数据集很小,并且步数太少时,这可能会导致无操作的 epoch。因此,在这些情况下最好使用较小的批次大小。
XLA 张量特性¶
XLA 张量内部是不可见的。 XLA 张量始终显示为连续且没有存储。网络不应该尝试检查 XLA 张量的步长。
在保存 XLA 张量之前,应将其移动到 CPU。 直接保存 XLA 张量会导致它们被加载回保存它们的设备。如果在加载时设备不可用,则加载将失败。在保存 XLA 张量之前将其移动到 CPU,可以让你决定将加载的张量放在哪个设备上。如果你想在没有 XLA 设备的机器上加载张量,则需要这样做。但是,在保存之前将 XLA 张量移动到 CPU 时,应注意,跨设备类型移动张量不会保留视图关系。相反,应在加载张量后根据需要重建视图。
使用 Python 的 copy.copy 复制 XLA 张量会返回一个深拷贝,而不是浅拷贝。 使用 XLA 张量的视图来获取其浅拷贝。
处理共享权重。 模块可以通过将一个模块的参数设置为另一个模块来共享权重。这种模块权重的“绑定”应该在将模块移动到 XLA 设备 **之后** 完成。否则,将在 XLA 设备上创建共享张量的两个独立副本。
更多调试工具¶
我们不希望用户使用本节中的工具来调试他们的模型。但是,当您提交错误报告时,我们可能会要求您提供这些工具,因为它们提供了指标报告中没有的额外信息。
环境变量¶
还有一些环境变量控制PyTorch/XLA软件堆栈的行为。
设置这些变量会导致不同程度的性能下降,因此它们应该只在调试时启用。
XLA_IR_DEBUG
: 启用在创建 IR 节点时捕获Python堆栈跟踪,从而允许了解哪个PyTorch操作负责生成 IR。XLA_HLO_DEBUG
: 启用在_XLA_IR_DEBUG_处于活动状态时捕获的Python堆栈帧,以传播到XLA _HLO_元数据。XLA_SAVE_TENSORS_FILE
: 用于在执行期间转储 IR 图的文件的路径。请注意,如果该选项保持启用状态并且PyTorch程序长时间运行,则该文件可能会变得非常大。图将附加到该文件,因此要从运行到运行获得干净的表格,应显式删除该文件。XLA_SAVE_TENSORS_FMT
: 存储在_XLA_SAVE_TENSORS_FILE_文件中的图的格式。可以是text
(默认值)、dot
(_Graphviz_格式)或hlo
。XLA_METRICS_FILE
: 如果设置,则为在每一步保存内部指标的本地文件的路径。如果已存在,指标将附加到该文件。XLA_SAVE_HLO_FILE
: 如果设置,则为本地文件的路径,在编译/执行错误的情况下,将保存有问题的 HLO 图。XLA_GET_TENSORS_OPBYOP
: 启用纯 _OpByOp_ 分派。PyTorch/XLA软件尝试将许多PyTorch操作融合到一个计算图中,但有时,出于调试目的,或者如果PyTorch代码具有非常动态的性质(在形状或图方面),最好强制执行 _OpByOp_ 模式(每个 IR 节点都被降低到一个单独的 _XLA_计算,并进行链式执行)。如果此环境变量设置为 1,则在“获取张量”操作期间启用 _OpByOp_(PyTorch/XLA 用于从 _TPU_ 设备将中间值取回 _PyTorch_ CPU 张量的操作)。XLA_SYNC_TENSORS_OPBYOP
: 与 _XLA_GET_TENSORS_OPBYOP_ 相同,但用于“同步张量”操作(在步骤结束时使用的操作,以刷新挂起的 IR 计算并将它们具体化为 _TPU_ 设备数据)。XLA_SYNC_WAIT
: 强制 XLA 张量同步操作等待其完成,然后再进行下一步。XLA_USE_BF16
: 如果设置为 1,则在将所有 *PyTorch* *Float* 值发送到 *TPU* 设备时将其转换为 *BiFloat16*。请注意,当使用XLA_USE_BF16=1
时,张量算术将在降低的精度下完成,因此如果随着时间的推移累积,张量将不准确。例如# In reduced bfloat16 precision >>> torch.tensor(4096, dtype=torch.bfloat16) + torch.tensor(1, dtype=torch.bfloat16) tensor(4096., dtype=torch.bfloat16) # Whereas in full float32 precision >>> torch.tensor(4096) + torch.tensor(1) tensor(4097)
因此,要获得准确的指标,例如许多步骤上的平均损失值,请使用手动混合精度,其中指标保持在 FP32 中。
XLA_USE_F16
: 如果设置为 1,则在将所有 *PyTorch* *Float* 值发送到支持它们的设备时,将其转换为 *Float16* (*PyTorch* *Half* 类型)。XLA_USE_32BIT_LONG
: 如果设置为 1,则将 *PyTorch* *Long* 类型映射到 *XLA* 32 位类型。在编写时 TPU 硬件的版本上,64 位整数计算很昂贵,因此设置此标志可能会有所帮助。用户应根据 *PyTorch* *Long* 值在其中的使用情况验证将值截断为 32 位是否为有效操作。TF_CPP_LOG_THREAD_ID
: 如果设置为 1,则 TF 日志将显示线程 ID,这有助于调试多线程进程。TF_CPP_VMODULE
: 用于 TF VLOG 的环境变量,其形式为TF_CPP_VMODULE=name=value,...
。请注意,对于 VLOG,您必须设置TF_CPP_MIN_LOG_LEVEL=0
。对于 PyTorch/XLA,使用类似TF_CPP_VMODULE=tensor=5
的配置将启用日志记录,例如2019-10-03 17:23:56.419040: I 27891 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 4211381954965020633 on device TPU:3 done! 2019-10-03 17:23:56.419448: I 27890 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 15483856951158150605 on device TPU:5 done! 2019-10-03 17:23:56.419539: I 27896 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 4211381954965020633 on device TPU:4 done! ...
TF_CPP_MIN_LOG_LEVEL
: 用于打印消息的级别。TF_CPP_MIN_LOG_LEVEL=0
将打开 INFO 日志记录,TF_CPP_MIN_LOG_LEVEL=1
WARNING 等等。我们的 PyTorch/XLATF_VLOG
默认使用tensorflow::INFO
级别,因此要查看 VLOG,请设置TF_CPP_MIN_LOG_LEVEL=0
。XLA_DUMP_HLO_GRAPH
: 如果设置为=1
,在编译或执行错误的情况下,错误的 HLO 图将作为xla_util.cc
抛出的运行时错误的一部分被转储。
获取堆栈跟踪¶
如果 PyTorch 进程挂起,包含堆栈跟踪以及 GitHub 问题可能会有用。
首先要找出 PyTorch 进程关联的 PID。使用 ps
命令可以找到该信息。它将是一个运行你的主 python 文件的 python 进程。
为了允许 GDB 附加用户进程,应以 root 身份运行以下命令
echo 0 > /proc/sys/kernel/yama/ptrace_scope
上述命令保持活动状态,直到机器重新启动。
然后,给定 PID,可以使用以下命令获取堆栈跟踪
./scripts/dump_stacks.py PID > /tmp/stack-traces.log
使用 debug_run.py 收集调试信息¶
在 scripts/debug_run.py
中提供了一个实用程序,可用于创建包含调试 PyTorch/XLA 执行所需信息的 tar.gz
存档。
示例
./scripts/debug_run.py --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...]
建议使用 python -u
标志来禁用缓冲,以便捕获的日志正确交错(否则 STDOUT 将在所有 STDERR 之后呈现)。
上面的命令行示例将在文件系统上留下包含存档信息的临时文件夹。使用 --tidy
标志在退出时将其删除
./scripts/debug_run.py --tidy --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT [ARGS...]
然后,在必要时将 debug_run.tar.gz
文件附加到错误报告。
由于脚本将收集大量数据,因此通常应让它运行不超过一百步左右。
如果 SCRIPT 有参数来控制步数,则应使用这些参数,否则按 CTRL^C
将中断运行。
还建议以单核模式运行,以最大程度地减少数据量。在调试执行问题时,也强烈建议以单核模式运行。
常见问题¶
Missing XLA configuration
错误消息:如果您使用 TPU,则需要设置XRT_TPU_CONFIG
。如果您使用 GPU,则设置GPU_NUM_DEVICES=N
用于N
个 GPU。如果您使用 CPU,则设置XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"
和XRT_WORKERS="localservice:0;grpc://:9002"
PJRT 运行时 (Beta)¶
此文档反映了当前 nightly 版本中 PJRT 支持的现状。查看 r2.0 分支上的相同文档,了解最新稳定版本的状况。
PyTorch/XLA 团队目前正在从当前支持的 XRT 运行时迁移到 PJRT 运行时,该运行时由 JAX 使用。
PJRT 可在 PyTorch/XLA 2.0 中预览。我们计划将 PJRT 作为我们的官方支持运行时,因此我们鼓励所有用户尝试使用它。我们的目标是在 2.1 版本中使 PJRT 稳定,因此如果您遇到 PJRT 的错误,请在 GitHub 上提交一个带有 runtime
标签的问题。
PyTorch/XLA r2.0 中的新功能:
如果您没有传入任何其他运行时配置,则默认情况下将配置 PJRT。如果您继续设置 XRT 配置 (
XRT_TPU_CONFIG
),此更改不会产生任何影响libtpu
中的新 TPU 运行时实现将性能提高了高达 30%。新的
xm.rendezvous
实现可扩展到数千个 TPU 内核[实验性]
torch.distributed
支持 TPU v2 和 v3,包括pjrt://
init_method
[实验性] PJRT 中的单主机 GPU 支持。多主机支持即将推出!
TL;DR¶
要使用 PJRT 预览运行时,请将
PJRT_DEVICE
环境变量设置为CPU
、TPU
或GPU
。在 XRT 中,所有分布式工作负载都是多进程的,每个设备一个进程。在 PJRT 中的 TPU v2 和 v3 上,工作负载是多进程和多线程的(每个进程 4 个进程,每个进程 2 个线程),因此您的工作负载应该线程安全。有关更多信息,请参阅 TPU v2/v3 上的多线程 和 API 指南中的多进程部分。请记住一些关键区别
要以线程安全的方式初始化模型,请在初始化后将参数广播到所有副本(
torch_xla.experimental.pjrt.broadcast_master_param
)或从公共检查点加载每个副本的参数。对于其他随机数生成,请尽可能使用
torch.Generator
。全局torch
RNG 不是线程安全的,即使您在所有副本中设置了相同的torch.manual_seed
。要使用
torch.distributed
,请导入torch_xla.experimental.pjrt_backend
并使用pjrt://
init_method
。这些步骤对于 GPU 和 TPU v4 是可选的。
从 XRT 到 PJRT 的示例差异
import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.experimental.pjrt_backend
+import torch_xla.experimental.pjrt as pjrt
def _mp_fn(index):
device = xm.xla_device()
- dist.init_process_group('xla', rank=xm.get_ordinal(), world_size=xm.xrt_world_size())
+ dist.init_process_group('xla', init_method='pjrt://')
torch.manual_seed(42)
model = nn.Linear(128, 10).to(device)
+ # Optional for TPU v4 and GPU
+ pjrt.broadcast_master_param(model)
model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
for i in range(10):
data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Print mean parameters so we can confirm they're the same across replicas
print([p.mean() for p in model.parameters()])
if __name__ == '__main__':
- os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
+ # Recommended: set PJRT_DEVICE to your local device type
+ os.environ['PJRT_DEVICE'] = 'TPU'
xmp.spawn(_mp_fn)
优势¶
简单的运行时配置:只需将
PJRT_DEVICE
设置为TPU
、CPU
或GPU
,然后开始使用 XLA!或者,让 PJRT 根据您的环境自动选择设备。性能提升:gRPC 的开销减少,意味着端到端执行速度更快。在 TorchBench 2.0 上,我们在 TPU v4 上观察到训练时间提高了 35% 以上。
轻松执行 Pod:只需将您的代码复制到每个 TPU 工作器,并使用
gcloud compute tpus tpuvm ssh --worker=all
同时执行它们。更好的扩展性:消除了 XRT 对参数大小的限制,并支持高达 2048 个 TPU 芯片。
快速入门¶
要开始使用 PJRT 与 PyTorch/XLA,您只需设置 PJRT_DEVICE
环境变量。如果您正在使用 TPU v2 或 v3,请继续阅读以了解 TPU v2 和 v3 与 v4 之间的区别。
CPU¶
在任何安装了 PyTorch/XLA 的机器上,您都可以像这样在 CPU 上运行我们的 MNIST 示例
PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data
TPU¶
要创建安装了 PyTorch/XLA r1.13 的新 TPU
gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-1.13 --zone=us-central2-b --project=$PROJECT
在 v4-8 上,您可以像这样运行我们的 ResNet50 示例
git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
默认情况下,PJRT 将使用所有 TPU 芯片。要仅使用一个 TPU 芯片,请配置 TPU_PROCESS_BOUNDS
和 TPU_VISIBLE_CHIPS
TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Pods¶
在 TPU Pods 上,使用 gcloud
在每个 TPU 上并行运行您的命令
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
Docker¶
您也可以使用 Docker 在预安装了 PyTorch/XLA 的容器中运行您的工作负载
export DOCKER_IMAGE=gcr.io/...
# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"
# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"
请注意,docker run
需要对主机有特权访问权限 (--privileged
) 才能将 TPU 设备公开给容器。目前,TPU Pod 上的 Docker 仅支持主机网络 --net=host
。有关更多信息,请参阅 Cloud TPU 文档。
GPU¶
警告:GPU 支持仍处于高度实验阶段!
要使用 PJRT 与 GPU,只需设置 PJRT_DEVICE=GPU
并将 GPU_NUM_DEVICES
配置为主机上的设备数量。例如
PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
目前,仅支持单个主机,多主机 GPU 集群支持将在未来版本中添加。
与 XRT 的区别¶
尽管在大多数情况下,我们预计 PJRT 和 XRT 从最终用户的角度来看可以互换使用(尤其是在 TPU v4 上),但有一些细微的差别需要注意。重要的是,XRT 是围绕 TPU 节点架构设计的,因此它始终会在 TPU VM 上生成一个客户端和一个服务器进程。因此,每批输入都会因将数据序列化和反序列化以通过网络发送而产生额外的延迟。
PJRT 直接使用本地设备,无需中间服务器进程。在默认配置下,PJRT 将为每个 TPU 芯片创建一个进程,或为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档。
对于受 . 造成的开销限制的工作负载,可以提高性能。
在 XRT 下,服务器进程是唯一与 TPU 设备交互的进程,客户端进程无法直接访问 TPU 设备。在分析单主机 TPU(例如 v3-8 或 v4-8)时,通常会看到 8 个设备跟踪(每个 TPU 内核一个)。使用 PJRT,每个进程都拥有一个芯片,并且来自该进程的配置文件将仅显示 2 个 TPU 内核。
出于同样的原因,分析无法在使用 XRT 的 TPU Pod 上进行,因为服务器进程独立于用户的模型代码运行。PJRT 没有这种限制,因此可以在 TPU Pod 中为每个进程分析 2 个 TPU 内核。
PJRT 仅支持 TPU VM 架构,我们没有计划使用 PJRT 支持 TPU 节点架构。
使用 PJRT,运行时配置要简单得多。
xla_dist
不需要运行 TPU Pod 工作负载。相反,将您的代码复制到每个 TPU 主机([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)
)并在每个主机上并行运行代码(例如[gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh)
)xm.rendezvous
已使用 XLA 本地集体通信重新实现,以增强大型 TPU Pod 的稳定性。有关更多详细信息,请参见下文。
TPU v2/v3 上的多线程¶
在 TPU v2 和 v3 上,分布式工作负载始终以多线程方式运行,因为每个 TPU 内核将两个 TPU 内核公开为设备,并且一次只能有一个进程打开一个 TPU 芯片。在默认配置下,xmp.spawn
会自动生成尽可能多的进程(每个 TPU 主机 4 个),并为每个进程创建两个线程(每个 TPU 内核一个)。
注意:在 TPU v4 上,每个 TPU 芯片都表示为一个 PyTorch 设备,因此分布式工作负载将在 4 个进程中运行,每个进程只有一个线程。这与 XRT 的行为相同。
在大多数情况下,这不需要对现有代码进行重大更改。在大多数情况下,您需要进行的主要更改是模型初始化。因为 torch
的全局 RNG 在线程之间共享,即使您在每个副本中将 torch.manual_seed
设置为相同的值,结果也会在线程和运行之间有所不同。为了在副本之间获得一致的参数,请使用 torch_xla.experimental.pjrt.broadcast_master_param
将一个副本的参数广播到所有其他副本,或者从公共检查点加载每个副本的参数。
xm.rendezvous 的更改¶
PyTorch/XLA r2.0 中的新功能
使用 XRT,worker 0 运行一个网格主服务,所有 worker 上的所有进程都通过 gRPC 连接到该服务。在实践中,我们发现由于 worker 0 的入站连接数量众多,在拥有数千个芯片的 TPU pod 上运行单个网格主进程不可靠。单个客户端进程超时可能会导致故障,并迫使整个工作负载重新启动。
因此,我们使用原生 XLA 集体通信重新实现了 xm.rendezvous
,它在大型 TPU pod 上更加稳定且经过良好测试。这与 XRT 实现相比,带来了两个新的约束
由于有效载荷必须成为 XLA 图的一部分,因此在数据传输之前和之后都会调用
xm.mark_step
。在模型代码中间调用xm.rendezvous
可能会强制进行不必要的编译。由于 XLA 不允许在部分 worker 上运行集体操作,因此所有 worker 必须参与
rendezvous
。
如果您需要 xm.rendezvous
的旧行为(即在不更改 XLA 图和/或同步部分 worker 的情况下通信数据),请考虑使用 ``torch.distributed.barrier` <https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.barrier>`_ 或 [torch.distributed.all_gather_object](https://pytorch.ac.cn/docs/stable/distributed.html#torch.distributed.all_gather_object)
,并使用 gloo
进程组。如果您也使用 xla
torch.distributed
后端,可以使用 torch.new_group
创建一个 gloo
子组。请参阅 PyTorch 文档中的 此示例。请记住这些约束
torch.distributed
在 TPU v2/v3 上并未完全支持。只有使用xla
后端的某些操作已实现,并且gloo
在多进程上下文中可能无法按预期工作。在我们的实验中,
gloo
无法很好地扩展到数千个 TPU 芯片,因此预计这种替代方案在大型规模上不如使用xm.rendezvous
与 PJRT 可靠。
PJRT 和 torch.distributed¶
PyTorch/XLA r2.0 中的新功能
在使用 PJRT 与 torch.distributed
和 [torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)
时,我们强烈建议使用新的 pjrt://
init_method
,它通过查询运行时自动查找副本 ID、世界大小和主 IP。例如
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt
# Required for `pjrt://` init_method
import torch_xla.experimental.pjrt_backend
def _all_gather(index: int):
# No need to pass in `rank` or `world_size`
dist.init_process_group('xla', init_method='pjrt://')
t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(output, t)
xm.mark_step()
print(output)
if __name__ == '__main__':
xmp.spawn(_all_gather)
注意:虽然在 TPU v4 上不需要 pjrt://
init_method,但仍然建议使用。如果您使用 env://
,则必须将 MASTER_ADDR
设置为具有设备 0 的 IP 主机,该主机 *不* 总是 worker 0。 pjrt://
init_method 会自动查找此 IP,并支持 TPU v2/v3。
有关在 PyTorch/XLA 上使用 DistributedDataParallel
的更多信息,请参阅 TPU V4 上的 ``ddp.md` <./ddp.md>`_。有关使用 DDP 和 PJRT 结合的示例,请在 TPU 上运行以下 示例脚本
PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1
性能¶
TorchBench 显示,与 XRT 相比,PJRT 在跨任务的平均训练时间方面有所改进,在 TPU v4-8 上平均提高了 35% 以上。优势因任务和模型类型而异,范围从 0% 到 175%。以下图表显示了按任务细分的改进情况
新的 TPU 运行时¶
PyTorch/XLA r2.0 中的新功能
PyTorch/XLA r2.0 版本引入了对 PJRT 插件 API 的支持,用于访问 libtpu
中基于 TFRT 的新 TPU 运行时。现在,当设置 PJRT_DEVICE=TPU
时,这是默认运行时。2.0 版本中,使用 PJRT_DEVICE=TPU_LEGACY
仍然可以使用 1.13 中使用的基于 StreamExecutor 的旧版 TPU 运行时,但它将在未来版本中删除。如果您遇到仅在 TPU
上发生而不在 TPU_LEGACY
上发生的错误,请在 GitHub 上提交问题。
在大多数情况下,我们预计两种运行时的性能将相似,但在某些情况下,新运行时可能会快达 30%。以下图表显示了按任务细分的改进情况
注意:此图表中显示的改进也包含在 PJRT 与 XRT 的比较中。
PyTorch XLA 中的 TorchDynamo(torch.compile) 集成¶
Torchdynamo 是一种 Python 级别的 JIT 编译器,旨在使未修改的 PyTorch 程序更快。它为编译器后端提供了一个干净的 API,其最大特点是在执行 Python 字节码之前动态修改它。在 pytorch/xla 2.0 版本中,PyTorch/XLA 为 TorchDynamo 提供了一个实验性后端,用于推理和训练。
XLA 桥接的工作方式是,当 Dynamo 识别到模型模式时,它会提供一个 TorchFX 图,而 PyTorch/XLA 将使用现有的 Lazy Tensor 技术来编译 FX 图并返回已编译的函数。
推理¶
以下是用 torch.compile
运行 resnet18 的一个小代码示例。
import torch
imprt torchvision
import torch_xla.core.xla_model as xm
def eval_model(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
xla_resnet18, backend='torchxla_trace_once')
for data, _ in loader:
output = dynamo_resnet18(data)
注意: 推理后端名称
torchxla_trace_once
可能会有所更改。
使用 torch.compile
,您会发现 PyTorch/XLA 只会在初始化时跟踪一次 resent18 模型,并在每次调用 dynamo_resnet18
时执行已编译的二进制文件,而不是每次都跟踪模型。请注意,目前 Dynamo 不支持回退,因此如果存在 XLA 无法跟踪的操作,它将出错。我们将在即将发布的 2.1 版本中修复此问题。以下是在 Cloud TPU v4-8 上使用 torch bench 对 Dynamo 和 Lazy 进行推理速度分析的比较结果。
resnet18 | 1.768 resnet50 | 1.61 resnext50_32x4d | 1.328 alexnet | 1.261 mobilenet_v2 | 2.017 mnasnet1_0 | 1.686 vgg16 | 1.155 BERT_pytorch | 3.502 squeezenet1_1 | 1.674 timm_vision_transformer | 3.138 平均 | 1.9139
训练¶
PyTorch/XLA 也支持 Dynamo 进行训练,但它处于非常实验阶段,我们正在与 PyTorch Compiler 团队合作迭代实现。在 2.0 版本中,它只支持前向和后向传递,而不支持优化器。以下是用 torch.compile
训练 resnet18 的示例。
import torch
imprt torchvision
import torch_xla.core.xla_model as xm
def train_model(model, data, target):
loss_fn = torch.nn.CrossEntropyLoss()
pred = model(data)
loss = loss_fn(pred, target)
loss.backward()
return pred
def train_model_main(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
train_model, backend='aot_torchxla_trace_once')
for data, target in loader:
output = dynamo_train_model(xla_resnet18, data, target)
注意: 我们在此使用的后端是
aot_torchxla_trace_once
(可能会有所更改),而不是torchxla_trace_once
。
我们预计在每个训练步骤中提取和执行 3 个图,而不是使用 Lazy Tensor 时的一个训练步骤。以下是在 Cloud TPU v4-8 上使用 torch bench 对 Dynamo 和 Lazy 进行训练速度分析的比较结果。
resnet50 | 0.937 resnet18 | 1.003 BERT_pytorch | 1.869 resnext50_32x4d | 1.139 alexnet | 0.802 mobilenet_v2 | 0.672 mnasnet1_0 | 0.967 vgg16 | 0.742 timm_vision_transformer | 1.69 squeezenet1_1 | 0.958 平均 | 1.0779
注意: 我们为每个模型的 fwd 和 bwd 运行单个步骤,然后收集 e2e 时间。在现实世界中,我们将在每个训练作业中运行多个步骤,这可以轻松地隐藏执行中的跟踪成本(因为它是非同步的)。Lazy Tensor 在这种情况下将具有更好的性能。
我们目前正在研究优化器支持,它将在 nightly 版本中很快提供,但不会在 2.0 版本中提供。
结论¶
TorchDynamo 为编译器后端提供了一种非常有前景的方式,可以隐藏用户端的复杂性,并轻松地以图形格式检索模型代码。与 PyTorch/XLA 的传统 Lazy Tensor 提取图形的方式相比,TorchDynamo 可以跳过每次迭代的图形跟踪,从而提供更好的推理响应时间。但是,TorchDynamo 尚未跟踪通信操作(如 all_reduce
和 all_gather
),并且它为前向和反向提供单独的图形,这会影响 xla 性能。与 Lazy Tensor 相比,这些功能差距使其在实际训练用例中效率较低,特别是跟踪成本可以在训练中与执行重叠。PyTorch/XLA 团队将继续投资 TorchDynamo,并与上游合作以完善训练故事。
PyTorch XLA 中的全分片数据并行 (FSDP)¶
PyTorch XLA 中的全分片数据并行 (FSDP) 是一种用于在数据并行工作器之间分片模块参数的实用程序。
使用示例
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
也可以分别分片单个层,并让外部包装器处理任何剩余的参数。
注意事项
XlaFullyShardedDataParallel
类支持 ZeRO-2 优化器(分片梯度和优化器状态)和 ZeRO-3 优化器(分片参数、梯度和优化器状态),如 https://arxiv.org/abs/1910.02054 中所述。ZeRO-3 优化器应通过嵌套 FSDP 实现,并使用
reshard_after_forward=True
。有关示例,请参见test/test_train_mp_mnist_fsdp_with_ckpt.py
和test/test_train_mp_imagenet_fsdp.py
。对于无法容纳在单个 TPU 内存或主机 CPU 内存中的大型模型,应将子模块构建与内部 FSDP 包装交织在一起。有关示例,请参见 ``FSDPViTModel` <https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py>`_。
提供了一个简单的包装器
checkpoint_module
(基于torch_xla.utils.checkpoint.checkpoint
,来自 https://github.com/pytorch/xla/pull/3524),用于对给定的nn.Module
实例执行 梯度检查点。有关示例,请参见test/test_train_mp_mnist_fsdp_with_ckpt.py
和test/test_train_mp_imagenet_fsdp.py
。自动包装子模块:除了手动嵌套 FSDP 包装之外,还可以指定一个
auto_wrap_policy
参数,以使用内部 FSDP 自动包装子模块。size_based_auto_wrap_policy
在torch_xla.distributed.fsdp.wrap
中是auto_wrap_policy
可调用的一个示例,此策略包装参数数量大于 100M 的层。transformer_auto_wrap_policy
在torch_xla.distributed.fsdp.wrap
中是auto_wrap_policy
可调用的一个示例,适用于类似 Transformer 的模型架构。
例如,要使用内部 FSDP 自动包装所有 torch.nn.Conv2d
子模块,可以使用
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
此外,还可以指定一个 auto_wrapper_callable
参数来使用自定义的可调用包装器来包装子模块(默认包装器只是 XlaFullyShardedDataParallel
类本身)。例如,可以使用以下方法将梯度检查点(即激活检查点/重计算)应用于每个自动包装的子模块。
from torch_xla.distributed.fsdp import checkpoint_module
auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
checkpoint_module(m), *args, **kwargs)
在执行优化器步骤时,直接调用
optimizer.step
,不要调用xm.optimizer_step
。后者会跨进程减少梯度,这对于 FSDP(参数已经分片)来说是不必要的。在训练期间保存模型和优化器检查点时,每个训练进程都需要保存其自己的(分片)模型和优化器状态字典的检查点(使用
master_only=False
并在xm.save
中为每个进程设置不同的路径)。恢复时,需要加载对应进程的检查点。请同时保存
model.get_shard_metadata()
和model.state_dict()
,如下所示,并使用consolidate_sharded_model_checkpoints
将分片模型检查点拼接成完整的模型状态字典。有关示例,请参见test/test_train_mp_mnist_fsdp_with_ckpt.py
。.. code-block:: python3- ckpt = {
‘model’: model.state_dict(), ‘shard_metadata’: model.get_shard_metadata(), ‘optimizer’: optimizer.state_dict(),
} ckpt_path = f’/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth’ xm.save(ckpt, ckpt_path, master_only=False)
检查点合并脚本也可以从命令行启动,如下所示。.. code-block:: bash
# 通过命令行工具合并保存的检查点 python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts –ckpt_prefix /path/to/your_sharded_checkpoint_files –ckpt_suffix “_rank--of-.pth”
此类的实现很大程度上受到 https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html 中的 fairscale.nn.FullyShardedDataParallel
的启发,并且在很大程度上遵循其结构。与 fairscale.nn.FullyShardedDataParallel
最大的区别在于,在 XLA 中,我们没有显式参数存储,因此在这里我们采用不同的方法来释放完整的参数以用于 ZeRO-3。
MNIST 和 ImageNet 上的示例训练脚本¶
MNIST: ``test/test_train_mp_mnist_fsdp_with_ckpt.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_fsdp_with_ckpt.py>`_ (它还测试了检查点合并)
ImageNet: ``test/test_train_mp_imagenet_fsdp.py` <https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_fsdp.py>`_
安装¶
FSDP 在 PyTorch/XLA 1.12 版本和更新的 nightly 版本中可用。请参考 https://github.com/pytorch/xla#-available-images-and-wheels 获取安装指南。
克隆 PyTorch/XLA 仓库¶
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
在 v3-8 TPU 上训练 MNIST¶
它在 2 个 epoch 内获得了大约 98.9 的准确率
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
--batch_size 16 --drop_last --num_epochs 2 \
--use_nested_fsdp --use_gradient_checkpointing
此脚本在结束时自动测试检查点合并。您也可以通过以下方式手动合并分片检查点
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
--ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
--ckpt_suffix "_rank-*-of-*.pth"
在 v3-8 TPU 上使用 ResNet-50 训练 ImageNet¶
它在 100 个 epoch 内获得了大约 75.9 的准确率;将 ImageNet-1k 下载到 /datasets/imagenet-1k
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--datadir /datasets/imagenet-1k --drop_last \
--model resnet50 --test_set_batch_size 64 --eval_interval 10 \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
您还可以添加 --use_gradient_checkpointing
(需要与 --use_nested_fsdp
或 --auto_wrap_policy
一起使用)来在残差块上应用梯度检查点。
在 TPU pod 上的示例训练脚本(具有 100 亿个参数)¶
要训练无法放入单个 TPU 的大型模型,应在构建整个模型时应用自动包装或手动包装带有内部 FSDP 的子模块,以实现 ZeRO-3 算法。
请参阅 https://github.com/ronghanghu/vit_10b_fsdp_example,了解使用此 XLA FSDP PR 对 Vision Transformer (ViT) 模型进行分片训练的示例。
如何执行 DistributedDataParallel
¶
本文档展示了如何在 xla 中使用 torch.nn.parallel.DistributedDataParallel,并进一步描述了它与原生 xla 数据并行方法的区别。
背景/动机¶
客户长期以来一直要求能够将 PyTorch 的 DistributedDataParallel API 与 xla 一起使用。在这里,我们将其作为一项实验性功能启用。
如何使用 DistributedDataParallel¶
对于那些从 PyTorch 急切模式切换到 XLA 的用户,以下是在将您的急切 DDP 模型转换为 XLA 模型时需要进行的所有更改。我们假设您已经知道如何在单个设备上使用 XLA on a single device。
导入 xla 特定的分布式包
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
初始化 xla 进程组,类似于其他进程组,如 nccl 和 gloo。
dist.init_process_group("xla", rank=rank, world_size=world_size)
如果您需要,请使用 xla 特定的 API 获取 rank 和 world_size。
new_rank = xm.get_ordinal()
world_size = xm.xrt_world_size()
将
gradient_as_bucket_view=True
传递给 DDP 包装器。
ddp_model = DDP(model, gradient_as_bucket_view=True)
最后,使用 xla 特定的启动器启动您的模型。
xmp.spawn(demo_fn)
在这里,我们将所有内容整合在一起(该示例实际上取自 DDP 教程)。您编写代码的方式与急切体验非常相似。只是在单个设备上使用 xla 特定的操作,以及对您的脚本进行上述五项更改。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
# additional imports for xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the xla process group
dist.init_process_group("xla", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 1000000)
self.relu = nn.ReLU()
self.net2 = nn.Linear(1000000, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank):
# xla specific APIs to get rank, world_size.
new_rank = xm.get_ordinal()
assert new_rank == rank
world_size = xm.xrt_world_size()
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to XLA device
device = xm.xla_device()
model = ToyModel().to(device)
# currently, graident_as_bucket_view is needed to make DDP work for xla
ddp_model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10).to(device))
labels = torch.randn(20, 5).to(device)
loss_fn(outputs, labels).backward()
optimizer.step()
# xla specific API to execute the graph
xm.mark_step()
cleanup()
def run_demo(demo_fn):
# xla specific launcher
xmp.spawn(demo_fn)
if __name__ == "__main__":
run_demo(demo_basic)
基准测试¶
使用假数据的 Resnet50¶
以下结果是在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 收集的,命令为:python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
。统计指标是使用此 pull request 中的脚本生成的。速率单位为每秒图像数。
类型 | 平均值 | 中位数 | 第 90 个百分位 | 标准差 | 变异系数 |
xm.optimizer_step | 418.54 | 419.22 | 430.40 | 9.76 | 0.02 |
DDP | 395.97 | 395.54 | 407.13 | 7.60 | 0.02 |
我们原生分布式数据并行方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 395.97 / 418.54 = 5.39%。考虑到 DDP 包装器在跟踪 DDP 运行时引入了额外的开销,这个结果似乎是合理的。
使用伪造数据训练 MNIST¶
以下结果是在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 收集的,命令为:python test/test_train_mp_mnist.py --fake_data
。统计指标是使用此 pull request 中的脚本生成的。速率单位为每秒图像。
类型 | 平均值 | 中位数 | 第 90 个百分位 | 标准差 | 变异系数 |
xm.optimizer_step | 17864.19 | 20108.96 | 24351.74 | 5866.83 | 0.33 |
DDP | 10701.39 | 11770.00 | 14313.78 | 3102.92 | 0.29 |
我们原生分布式数据并行方法与 DistributedDataParallel 包装器之间的性能差异为:1 - 14313.78 / 24351.74 = 41.22%。由于数据集较小,并且前几轮受到数据加载的严重影响,因此我们比较的是 90% 的值。这种速度下降很大,但考虑到模型很小,这是有道理的。额外的 DDP 运行时跟踪开销很难摊销。
使用真实数据训练 MNIST¶
以下结果是在 TPU VM V3-8 环境中使用 ToT PyTorch 和 PyTorch/XLA 收集的,命令为:python test/test_train_mp_mnist.py --logdir mnist/
。

我们可以观察到,尽管 DDP 包装器在最后仍然达到了 97.48% 的高准确率,但它的收敛速度比原生 XLA 方法慢。(原生方法达到了 99% 的准确率。)
免责声明¶
此功能仍处于实验阶段,并且正在积极开发中。请谨慎使用,并随时将任何错误报告给 xla github 仓库。对于那些对原生 xla 数据并行方法感兴趣的人,这里有 教程。
以下是一些正在调查中的已知问题
gradient_as_bucket_view=True
需要强制执行。在使用
torch.utils.data.DataLoader
时遇到了一些问题。test_train_mp_mnist.py
使用真实数据在退出之前崩溃。
如何使用 PyTorch/XLA:GPU¶
PyTorch/XLA 使 PyTorch 用户能够利用 XLA 编译器,该编译器支持包括 TPU、GPU、CPU 等在内的加速器。本文档将介绍在 nvidia gpu 实例上运行 PyTorch/XLA 的基本步骤。
创建 GPU 实例¶
Pytorch/XLA 目前发布了预构建的 docker 镜像和轮子,支持 cuda11.2 和 python 3.7/3.8。我们建议用户使用相应的配置创建 GPU 实例。有关 docker 镜像和轮子的完整列表,请参考 此文档。
环境设置¶
Docker¶
sudo docker pull gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.2
sudo apt-get install -y apt-transport-https ca-certificates curl gnupg-agent software-properties-common
distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
curl -s -L https://nvda.org.cn/nvidia-docker/gpgkey | sudo apt-key add -
curl -s -L https://nvda.org.cn/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
sudo systemctl restart docker
sudo docker run --gpus all -it -d gcr.io/tpu-pytorch/xla:nightly_3.7\8_cuda_11.2 bin/bash
sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash
请注意,您需要重启 docker 才能使 docker 容器中的 gpu 设备可见。登录 docker 后,您可以使用 nvidia-smi
验证设备是否已正确设置。
(pytorch) root@20ab2c7a2d06:/# nvidia-smi
Thu Dec 8 06:24:29 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |
| N/A 36C P0 38W / 300W | 0MiB / 16384MiB | 1% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
轮子¶
pip3 install torch=1.13
pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl
运行一个简单的模型¶
为了运行以下示例,您需要克隆 pytorch/xla 仓库以访问 imagenet 示例(我们已经在 docker 中克隆了它)。
(pytorch) root@20ab2c7a2d06:/# export GPU_NUM_DEVICES=1
(pytorch) root@20ab2c7a2d06:/# python pytorch/xla/test/test_train_mp_imagenet.py --fake_data
==> Preparing data..
Epoch 1 train begin 06:12:38
2022-12-08 06:13:12.452874: W 79 tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:729] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms. Conv: (f32[128,256,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[128,256,14,14]{3,2,1,0}, f32[3,3,256,256]{1,0,2,3}), window={size=3x3 stride=2x2 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
2022-12-08 06:13:13.780992: W 79 tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:729] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms. Conv: (f32[128,128,56,56]{3,2,1,0}, u8[0]{0}) custom-call(f32[128,128,28,28]{3,2,1,0}, f32[3,3,128,128]{1,0,2,3}), window={size=3x3 stride=2x2 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBackwardInput", backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=2.82 GlobalRate=2.82 Time=06:13:23
| Training Device=xla:0/0 Epoch=1 Step=20 Loss=6.79297 Rate=117.16 GlobalRate=45.84 Time=06:13:36
| Training Device=xla:0/0 Epoch=1 Step=40 Loss=6.43628 Rate=281.16 GlobalRate=80.49 Time=06:13:43
| Training Device=xla:0/0 Epoch=1 Step=60 Loss=5.83108 Rate=346.88 GlobalRate=108.82 Time=06:13:49
| Training Device=xla:0/0 Epoch=1 Step=80 Loss=4.99023 Rate=373.62 GlobalRate=132.43 Time=06:13:56
| Training Device=xla:0/0 Epoch=1 Step=100 Loss=3.92699 Rate=384.33 GlobalRate=152.40 Time=06:14:02
| Training Device=xla:0/0 Epoch=1 Step=120 Loss=2.68816 Rate=388.35 GlobalRate=169.49 Time=06:14:09
AMP(自动混合精度)¶
AMP 在 GPU 训练中非常有用,PyTorch/XLA 重用 Cuda 的 AMP 规则。您可以查看我们的 mnist 示例 和 imagenet 示例。请注意,我们还使用了一个修改版本的 优化器 来避免设备和主机之间额外的同步。