快捷方式

PyTorch/XLA 文档

PyTorch/XLA 是一个 Python 包,它使用 XLA 深度学习编译器来连接 PyTorch 深度学习框架和 Cloud TPU。

XLA 设备上的 PyTorch

PyTorch 使用 torch_xla 包 在 XLA 设备(如 TPU)上运行。本文档介绍了如何在这些设备上运行模型。

创建 XLA 张量

PyTorch/XLA 为 PyTorch 添加了一种新的 xla 设备类型。此设备类型的工作方式与其他 PyTorch 设备类型相同。例如,以下是如何创建和打印 XLA 张量

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

此代码应该看起来很熟悉。PyTorch/XLA 使用与常规 PyTorch 相同的接口,并增加了一些功能。导入 torch_xla 初始化 PyTorch/XLA,而 xm.xla_device() 返回当前的 XLA 设备。根据您的环境,这可能是 CPU 或 TPU。

XLA 张量是 PyTorch 张量

PyTorch 操作可以像在 CPU 或 CUDA 张量上一样在 XLA 张量上执行。

例如,可以将 XLA 张量加在一起

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

或进行矩阵乘法

print(t0.mm(t1))

或与神经网络模块一起使用

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

与其他设备类型一样,XLA 张量仅适用于同一设备上的其他 XLA 张量。因此,以下代码

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

将引发错误,因为 torch.nn.Linear 模块位于 CPU 上。

在 XLA 设备上运行模型

构建新的 PyTorch 网络或转换现有网络以在 XLA 设备上运行只需要几行 XLA 特定代码。以下代码片段在单个设备和使用 XLA 多处理的多个设备上运行时突出显示了这些行。

在单个 XLA 设备上运行

以下代码片段显示了一个在单个 XLA 设备上训练的网络

import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  optimizer.step()
  xm.mark_step()

此代码片段突出显示了将模型切换到在 XLA 上运行有多么容易。模型定义、数据加载器、优化器和训练循环可以在任何设备上运行。唯一的 XLA 特定代码是获取 XLA 设备并标记步骤的几行代码。在每次训练迭代结束时调用 xm.mark_step() 会导致 XLA 执行其当前图并更新模型的参数。有关 XLA 如何创建图和运行操作的更多信息,请参阅 XLA 张量深入解读

使用多处理在多个 XLA 设备上运行

PyTorch/XLA 使通过在多个 XLA 设备上运行来加速训练变得容易。以下代码片段显示了如何实现

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

此多设备代码片段与之前的单设备代码片段之间存在三个差异。让我们逐一介绍。

  • xmp.spawn()

    • 创建每个运行 XLA 设备的进程。

    • 每个进程只能访问分配给当前进程的设备。例如,在 TPU v4-8 上,将启动 4 个进程,每个进程将拥有一个 TPU 设备。

    • 请注意,如果您在每个进程上打印 xm.xla_device(),您将在所有设备上看到 xla:0。这是因为每个进程只能看到一个设备。这并不意味着多进程功能失效。在 TPU v2 和 TPU v3 上,只有 PJRT 运行时才能执行,因为将有 #devices/2 个进程,并且每个进程将有 2 个线程(有关更多详细信息,请查看此 文档)。

  • MpDeviceLoader

    • 将训练数据加载到每个设备上。

    • MpDeviceLoader 可以包装在 PyTorch 数据加载器上。它可以将数据预加载到设备上,并将数据加载与设备执行重叠,以提高性能。

    • MpDeviceLoader 还会为您在每个 batches_per_execution(默认为 1)批次被生成时调用 xm.mark_step

  • xm.optimizer_step(optimizer)

    • 合并设备之间的梯度并发出 XLA 设备步骤计算。

    • 它基本上是一个 all_reduce_gradients + optimizer.step() + mark_step,并返回被约简的损失。

模型定义、优化器定义和训练循环保持不变。

**注意:** 重要的是要注意,在使用多处理时,用户只能从 xmp.spawn() 的目标函数(或任何以 xmp.spawn() 作为父函数的函数)内部开始检索和访问 XLA 设备。

有关在多个 XLA 设备上使用多处理训练网络的更多信息,请参阅 完整的多处理示例

在 TPU Pod 上运行

不同加速器的多主机设置可能非常不同。本文档将讨论多主机训练的设备无关部分,并以 TPU + PJRT 运行时(目前在 1.13 和 2.x 版本中可用)为例。

在开始之前,请查看我们位于 此处 的用户指南,其中将解释一些 Google Cloud 基础知识,例如如何使用 gcloud 命令以及如何设置项目。您还可以查看 此处 以获取所有 Cloud TPU 操作指南。本文档将重点关注 PyTorch/XLA 设置的视角。

假设您在 train_mnist_xla.py 中拥有上述 mnist 示例。如果它是单主机多设备训练,您将 ssh 到 TPUVM 并运行如下命令

PJRT_DEVICE=TPU python3 train_mnist_xla.py

现在,为了在 TPU v4-16(它有两个主机,每个主机有 4 个 TPU 设备)上运行相同的模型,您需要

  • 确保每个主机都可以访问训练脚本和训练数据。这通常使用 gcloud scp 命令或 gcloud ssh 命令将训练脚本复制到所有主机来完成。

  • 在所有主机上同时运行相同的训练命令。

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"

上面的 gcloud ssh 命令将 ssh 到 TPUVM Pod 中的所有主机,并同时运行相同的命令。

**注意:** 您需要在 TPUVM VM 外部运行上述 gcloud 命令。

模型代码和训练脚本对于多进程训练和多主机训练是相同的。PyTorch/XLA 和底层基础架构将确保每个设备都了解全局拓扑以及每个设备的本地和全局序号。跨设备通信将在所有设备之间发生,而不是在本地设备之间发生。

有关 PJRT 运行时以及如何在 pod 上运行它的更多详细信息,请参阅此 文档。有关 PyTorch/XLA 和 TPU pod 的更多信息以及在 TPU pod 上运行 resnet50 和伪数据的完整指南,请参阅此 指南

XLA 张量深入解读

使用 XLA 张量和设备只需要更改几行代码。但即使 XLA 张量的行为非常类似于 CPU 和 CUDA 张量,它们的内部实现也不相同。本节介绍使 XLA 张量独一无二的特点。

XLA 张量是惰性的

CPU 和 CUDA 张量会立即或急切地启动操作。另一方面,XLA 张量是惰性的。它们将操作记录在一个图中,直到需要结果。像这样延迟执行可以让 XLA 对其进行优化。例如,多个独立操作的图可能会融合成单个优化的操作。

惰性执行通常对调用者是不可见的。PyTorch/XLA 自动构建图,将它们发送到 XLA 设备,并在 XLA 设备和 CPU 之间复制数据时进行同步。在显式同步 CPU 和 XLA 设备时插入屏障。有关我们惰性张量设计的更多信息,您可以阅读这篇论文

XLA 张量和 bFloat16

在 TPU 上运行时,PyTorch/XLA 可以使用bfloat16 数据类型。事实上,PyTorch/XLA 在 TPU 上对浮点类型(torch.floattorch.double)的处理方式不同。此行为由 XLA_USE_BF16XLA_DOWNCAST_BF16 环境变量控制。

  • 默认情况下,torch.floattorch.double 在 TPU 上都是 torch.float

  • 如果设置了 XLA_USE_BF16,则 torch.floattorch.double 在 TPU 上都为 bfloat16

  • 如果设置了 XLA_DOWNCAST_BF16,则 torch.float 在 TPU 上为 bfloat16,而 torch.double 在 TPU 上为 float32

  • 如果 PyTorch 张量具有 torch.bfloat16 数据类型,则它将直接映射到 TPU bfloat16(XLA BF16 原语类型)。

开发人员应注意,TPU 上的 XLA 张量将始终报告其 PyTorch 数据类型,而不管它们实际使用的数据类型是什么。此转换是自动且不透明的。如果 TPU 上的 XLA 张量移回 CPU,它将从其实际数据类型转换为其 PyTorch 数据类型。根据代码的操作方式,处理单元类型触发的此转换可能很重要。

内存布局

XLA 张量的内部数据表示对用户是不透明的。它们不公开其存储,并且始终看起来是连续的,这与 CPU 和 CUDA 张量不同。这允许 XLA 调整张量的内存布局以提高性能。

将 XLA 张量移入和移出 CPU

XLA 张量可以从 CPU 移动到 XLA 设备,也可以从 XLA 设备移动到 CPU。如果移动视图,则其正在查看的数据也会复制到另一个设备,并且不会保留视图关系。换句话说,一旦数据复制到另一个设备,它就与之前的设备或其上的任何张量没有关系。同样,根据代码的操作方式,理解和适应此转换可能很重要。

保存和加载 XLA 张量

应在保存之前将 XLA 张量移动到 CPU,如下面的代码片段所示

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

这使您可以将加载的张量放在任何可用的设备上,而不仅仅是初始化它们的设备。

根据上面关于将 XLA 张量移动到 CPU 的说明,在处理视图时必须小心。建议您在张量加载并移动到其目标设备后重新创建它们,而不是保存视图。

提供了一个实用程序 API 来保存数据,方法是先将其移动到 CPU

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xm.save(model.state_dict(), path)

在多个设备的情况下,上述 API 将仅保存主设备序号 (0) 的数据。

在内存相对于模型参数大小有限的情况下,提供了一个 API 来减少主机上的内存占用

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

此 API 将 XLA 张量一次一个地流式传输到 CPU,减少了使用的主机内存量,但它需要一个匹配的加载 API 来恢复

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)

可以直接保存 XLA 张量,但不建议这样做。XLA 张量始终加载回保存它们的设备,如果该设备不可用,则加载将失败。PyTorch/XLA 与 PyTorch 的所有其他部分一样,处于积极开发中,此行为将来可能会更改。

编译缓存

XLA 编译器将跟踪的 HLO 转换为在设备上运行的可执行文件。编译可能很耗时,并且在 HLO 在执行之间没有更改的情况下,可以将编译结果持久化到磁盘以供重用,从而显着减少开发迭代时间。

请注意,如果 HLO 在执行之间发生更改,则仍会发生重新编译。

这目前是一个实验性的选择加入 API,必须在执行任何计算之前激活它。初始化通过 initialize_cache API 完成

import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)

这将在指定路径初始化一个持久性编译缓存。 readonly 参数可用于控制工作程序是否能够写入缓存,这在使用共享缓存挂载进行 SPMD 工作负载时非常有用。

进一步阅读

其他文档可在PyTorch/XLA 存储库中找到。有关在 TPU 上运行网络的更多示例,请访问此处

PyTorch/XLA API

torch_xla

torch_xla.device(index: Optional[int] = None) device[source]

返回给定 XLA 设备的实例。

如果启用了 SPMD,则返回一个包装此进程可用的所有设备的虚拟设备。

参数

index – 要返回的 XLA 设备的索引。对应于 torch_xla.devices() 中的索引。

返回

一个 XLA torch.device

torch_xla.devices() List[device][source]

返回当前进程中可用的所有设备。

返回

一个 XLA torch.devices 列表。

torch_xla.device_count() int[source]

返回当前进程中可寻址设备的数量。

torch_xla.sync()[source]

启动所有挂起的图操作。

torch_xla.step()[source]

包装应分派到运行时的代码。

实验性:xla.step 仍在开发中。某些当前可以使用 xla.step 但不遵循最佳实践的代码将在将来的版本中变为错误。有关上下文,请参阅https://github.com/pytorch/xla/issues/6751

runtime

torch_xla.runtime.device_type() Optional[str][source]

返回当前 PjRt 设备类型。

如果尚未配置设备,则选择默认设备

torch_xla.runtime.local_process_count() int[source]

返回在此主机上运行的进程数。

torch_xla.runtime.local_device_count() int[source]

返回此主机上的设备总数。

假设每个进程具有相同数量的可寻址设备。

torch_xla.runtime.addressable_device_count() int[source]

返回此进程可见的设备数量。

torch_xla.runtime.global_device_count() int[source]

返回所有进程/主机上的设备总数。

torch_xla.runtime.global_runtime_device_count() int[source]

返回所有进程/主机上的运行时设备总数,尤其适用于SPMD。

torch_xla.runtime.world_size() int[source]

返回参与作业的进程总数。

torch_xla.runtime.global_ordinal() int[source]

返回此线程在所有进程中的全局序号。

全局序号范围为 [0, global_device_count)。全局序号不保证与TPU工作程序ID有任何可预测的关系,也不保证在每个主机上都是连续的。

torch_xla.runtime.local_ordinal() int[source]

返回此线程在此主机中的本地序号。

本地序号范围为 [0, local_device_count)。

torch_xla.runtime.get_master_ip() str[source]

检索运行时的主工作程序IP。此调用进入后端特定的发现API。

返回主工作程序的IP地址,表示为字符串。

torch_xla.runtime.use_spmd(auto: Optional[bool] = False)[source]
torch_xla.runtime.is_spmd()[source]

返回是否为执行设置了SPMD。

torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)[source]

初始化持久编译缓存。此API必须在执行任何计算之前调用。

参数
  • path – 存储持久缓存的路径。

  • readonly – 此工作程序是否应该对缓存具有写访问权限。

xla_model

torch_xla.core.xla_model.xla_device(n=None, devkind=None)[source]

返回给定 XLA 设备的实例。

参数
  • n (python:int, optional) – 要返回的特定实例(序号)。如果指定,则将返回特定的XLA设备实例。否则将返回devkind的第一个设备。

  • devkind (string..., optional) – 如果指定,则为设备类型,例如TPUCUDACPU或自定义PJRT设备。已弃用。

返回

具有请求实例的torch.device

torch_xla.core.xla_model.xla_device_hw(device)[source]

返回给定设备的硬件类型。

参数

device (string or torch.device) – 将映射到真实设备的xla设备。

返回

给定设备的硬件类型的字符串表示形式。

torch_xla.core.xla_model.is_master_ordinal(local=True)[source]

检查当前进程是否为主序号(0)。

参数

local (bool) – 是否应检查本地或全局主序号。在多主机复制的情况下,只有一个全局主序号(主机0,设备0),而有多个NUM_HOSTS本地主序号。默认值:True

返回

一个布尔值,指示当前进程是否为主序号。

torch_xla.core.xla_model.all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True)[source]

对输入张量(s)执行就地归约操作。

参数
  • reduce_type (string) – xm.REDUCE_SUMxm.REDUCE_MULxm.REDUCE_ANDxm.REDUCE_ORxm.REDUCE_MINxm.REDUCE_MAX之一。

  • inputs – 要执行所有归约操作的单个torch.Tensortorch.Tensor列表。

  • scale (python:float) – 归约后要应用的默认缩放值。默认值:1.0

  • groups (list, optional) –

    一个列表的列表,表示all_reduce()操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义两个组,一个具有[0, 1, 2, 3]副本,一个具有[4, 5, 6, 7]副本。如果为None,则将只有一个组,其中包含所有副本。

  • pin_layout (bool, optional) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时可能出现的数据损坏,但可能会导致某些xla编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。

返回

如果传递单个torch.Tensor,则返回值为一个torch.Tensor,其中包含归约值(跨副本)。如果传递列表/元组,则此函数对输入张量执行就地全归约操作,并返回列表/元组本身。

torch_xla.core.xla_model.all_gather(value, dim=0, groups=None, output=None, pin_layout=True)[source]

沿给定维度执行所有收集操作。

参数
  • value (torch.Tensor) – 输入张量。

  • dim (python:int) – 收集维度。默认值:0

  • groups (list, optional) –

    一个列表的列表,表示 all_gather() 操作的副本组。例如: [[0, 1, 2, 3], [4, 5, 6, 7]]

    定义两个组,一个具有[0, 1, 2, 3]副本,一个具有[4, 5, 6, 7]副本。如果为None,则将只有一个组,其中包含所有副本。

  • 输出 (torch.Tensor) – 可选的输出张量。

  • pin_layout (bool, optional) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时可能出现的数据损坏,但可能会导致某些xla编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。

返回

一个张量,在 dim 维度上包含所有参与副本的值。

torch_xla.core.xla_model.all_to_all(value, split_dimension, concat_dimension, split_count, groups=None, pin_layout=True)[source]

对输入张量执行 XLA AllToAll() 操作。

参见: https://tensorflowcn.cn/xla/operation_semantics#alltoall

参数
  • value (torch.Tensor) – 输入张量。

  • split_dimension (python:int) – 应该进行分割的维度。

  • concat_dimension (python:int) – 应该进行连接的维度。

  • split_count (python:int) – 分割次数。

  • groups (list, optional) –

    一个列表的列表,表示all_reduce()操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义两个组,一个具有[0, 1, 2, 3]副本,一个具有[4, 5, 6, 7]副本。如果为None,则将只有一个组,其中包含所有副本。

  • pin_layout (bool, optional) – 是否为此通信操作固定布局。布局固定可以防止参与通信的每个进程具有略微不同的程序时可能出现的数据损坏,但可能会导致某些xla编译失败。当您看到类似“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。

返回

all_to_all() 操作的结果 torch.Tensor

torch_xla.core.xla_model.add_step_closure(closure, args=(), run_async=False)[source]

将一个闭包添加到步骤结束时要运行的闭包列表中。

在模型训练期间,许多时候需要打印/报告(打印到控制台、发布到 TensorBoard 等)需要检查中间张量内容的信息。在模型代码的不同点检查不同张量的内容需要多次执行,通常会导致性能问题。添加步骤闭包将确保它在屏障之后运行,此时所有活动张量都将已具体化为设备数据。活动张量将包括闭包参数捕获的张量。因此,使用 add_step_closure() 将确保即使在排队多个闭包、需要检查多个张量时,也只会执行一次。步骤闭包将按照其排队的顺序依次运行。请注意,即使使用此 API 优化了执行,也建议每隔 N 步节流一次打印/报告事件。

参数
  • closure (callable) – 要调用的函数。

  • args (tuple) – 要传递给闭包的参数。

  • run_async – 如果为 True,则异步运行闭包。

torch_xla.core.xla_model.wait_device_ops(devices=[])[source]

等待给定设备上的所有异步操作完成。

参数

devices (string..., optional) – 需要等待其异步操作的设备。如果为空,则将等待所有本地设备。

torch_xla.core.xla_model.optimizer_step(optimizer, barrier=False, optimizer_args={}, groups=None, pin_layout=True)[source]

运行提供的优化器步骤并发出 XLA 设备步骤计算。

参数
  • optimizer (torch.Optimizer) – 需要调用其 step() 函数的 torch.Optimizer 实例。 step() 函数将使用 optimizer_args 命名参数进行调用。

  • barrier (bool, optional) – 是否应在此 API 中发出 XLA 张量屏障。如果使用 PyTorch XLA ParallelLoaderDataParallel 支持,则不需要此操作,因为屏障将由 XLA 数据加载器迭代器 next() 调用发出。默认值:False

  • optimizer_args (dict, optional) – 用于 optimizer.step() 调用的命名参数字典。

  • groups (list, optional) –

    一个列表的列表,表示all_reduce()操作的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义两个组,一个具有[0, 1, 2, 3]副本,一个具有[4, 5, 6, 7]副本。如果为None,则将只有一个组,其中包含所有副本。

  • pin_layout (bool, optional) – 在减少梯度时是否固定布局。有关详细信息,请参阅 xm.all_reduce

返回

optimizer.step() 调用返回的相同值。

torch_xla.core.xla_model.save(data, file_or_path, master_only=True, global_master=False)[source]

将输入数据保存到文件中。

保存的数据在保存之前会传输到 PyTorch CPU 设备,因此后续的 torch.load() 将加载 CPU 数据。在使用视图时必须小心。建议不要保存视图,而是在张量加载并移动到其目标设备后重新创建它们。

参数
  • data – 要保存的输入数据。Python 对象(列表、元组、集合、字典等)的任何嵌套组合。

  • file_or_path – 数据保存操作的目标位置。可以是文件路径或 Python 文件对象。如果 master_onlyFalse,则路径或文件对象必须指向不同的目标,否则来自同一主机的所有写入将相互覆盖。

  • master_only (bool, optional) – 是否只有主设备应保存数据。如果为 False,则 file_or_path 参数对于参与复制的每个序号都应是不同的文件或路径,否则同一主机上的所有副本都将写入同一位置。默认值:True

  • global_master (bool, optional) – 当 master_onlyTrue 时,此标志控制每个主机的 master(如果 global_masterFalse)是否保存内容,或者只有全局 master(序号 0)保存内容。默认值:False

  • sync (bool, optional) – 是否在保存张量后同步所有副本。如果为 True,则所有副本都必须调用 xm.save,否则主进程将挂起。

torch_xla.core.xla_model.rendezvous(tag, payload=b'', replicas=[])[source]

等待所有网格客户端到达指定的 rendezvous 点。

注意:PJRT 不支持 XRT 网格服务器,因此这实际上是 xla_rendezvous 的别名。

参数
  • tag (string) – 要加入的 rendezvous 的名称。

  • payload (bytes, optional) – 要发送到 rendezvous 的有效负载。

  • replicas (list, python:int) – 参与 rendezvous 的副本序号。空表示网格中的所有副本。默认值:[]

返回

所有其他内核交换的有效负载,其中内核序号 i 的有效负载位于返回元组的 i 位置。

torch_xla.core.xla_model.mesh_reduce(tag, data, reduce_fn)[source]

执行图外客户端网格归约。

参数
  • tag (string) – 要加入的 rendezvous 的名称。

  • data – 要归约的数据。 reduce_fn 可调用对象将接收一个列表,其中包含来自所有网格客户端进程(每个核心一个)的相同数据的副本。

  • reduce_fn (callable) – 一个函数,它接收一个 data 类对象的列表并返回归约结果。

返回

归约值。

torch_xla.core.xla_model.set_rng_state(seed, device=None)[source]

设置随机数生成器状态。

参数
  • seed (python:integer) – 要设置的状态。

  • device (string, optional) – 需要设置 RNG 状态的设备。如果缺少,则将设置默认设备种子。

torch_xla.core.xla_model.get_rng_state(device=None)[source]

获取当前运行的随机数生成器状态。

参数

device (string, optional) – 需要检索其 RNG 状态的设备。如果缺少,则将设置默认设备种子。

返回

RNG 状态,作为整数。

torch_xla.core.xla_model.get_memory_info(device: device) MemoryInfo[source]

检索设备内存使用情况。

参数

device – 需要获取内存信息的设备。

返回

包含给定设备内存使用情况的 MemoryInfo 字典。

torch_xla.core.xla_model.get_stablehlo(tensors=None) str[source]

以字符串格式获取计算图的 StableHLO。

如果 tensors 不为空,则会转储以 tensors 作为输出的图。如果 tensors 为空,则会转储整个计算图。TODO(lsy323): 当 tensors 为空时,一些中间张量也会作为输出转储。需要进一步调查。

对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不简单。建议使用空的 tensors

要在 StableHLO 中启用源代码行信息,请设置环境变量 XLA_HLO_DEBUG=1。

参数

tensors (list[torch.Tensor], optional) – 表示 StableHLO 图的输出/根的张量。

返回

字符串格式的 StableHLO 模块。

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors=None) bytes[source]

以字节码格式获取计算图的 StableHLO。

如果 tensors 不为空,则会转储以 tensors 作为输出的图。如果 tensors 为空,则会转储整个计算图。TODO(lsy323): 当 tensors 为空时,一些中间张量也会作为输出转储。需要进一步调查。

对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不简单。建议使用空的 tensors

参数

tensors (list[torch.Tensor], optional) – 表示 StableHLO 图的输出/根的张量。

返回

字节码格式的 StableHLO 模块。

分布式

class torch_xla.distributed.parallel_loader.ParallelLoader(loader, devices, batchdim=0, batches_per_execution=1, loader_prefetch_size=8, device_prefetch_size=4, host_to_device_transfer_threads=1, input_sharding=None)[source]

使用后台数据上传包装现有的 PyTorch DataLoader。

参数
  • loader (torch.utils.data.DataLoader) – 要包装的 PyTorch DataLoader。

  • devices (torch.device…) – 数据需要发送到的设备列表。由 loader 返回的第 i 个样本将发送到 devices[i % len(devices)]

  • batchdim (python:int, optional) – 包含批大小的维度。默认值:0

  • loader_prefetch_size (python:int, optional) – 从 loader 读取样本的线程使用的队列的最大容量,以便由将数据上传到设备的工作线程进行处理。默认值:8

  • device_prefetch_size (python:int, optional) – 每个设备队列的最大大小,工作线程将已发送到设备的张量存放在其中。默认值:4

  • host_to_device_transfer_threads (python:int, optional) – 并行工作以将数据从加载器队列传输到设备队列的线程数。默认值:1

  • input_sharding (ShardingSpec, optional) – 加载后应用于兼容输入张量的分片规范。默认值:None

per_device_loader(device)[source]

检索给定设备的加载器迭代器对象。

参数

device (torch.device) – 请求加载器的设备。

返回

给定 device 的加载器迭代器对象。这不是 torch.utils.data.DataLoader 接口,而是一个 Python 迭代器,它返回与包装的 torch.utils.data.DataLoader 返回的相同张量数据结构,但位于 XLA 设备上。

torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]

启用基于多处理的复制。

参数
  • fn (callable) – 对于参与复制的每个设备要调用的函数。该函数将以第一个参数作为复制中进程的全局索引,然后是 args 中传递的参数来调用。

  • args (tuple) – fn 的参数。默认值:空元组

  • nprocs (python:int) – 复制的进程/设备数量。目前,如果指定,可以是 1 或最大设备数量。

  • join (bool) – 调用是否应阻塞并等待已生成的进程完成。默认值:True

  • daemon (bool) – 已生成的进程是否应设置 daemon 标志(参见 Python 多处理 API)。默认值:False

  • start_method (string) – Python multiprocessing 进程创建方法。默认值:spawn

返回

torch.multiprocessing.spawn API 返回的相同对象。如果 nprocs 为 1,则将直接调用 fn 函数,并且 API 将返回 None。

spmd

torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Optional[Union[Tuple, int, str]]], use_dynamo_custom_op: bool = False) XLAShardedTensor[source]

使用 XLA 分区规范注释提供的张量。在内部,它将相应的 XLATensor 注释为针对 XLA SpmdPartitioner 传递进行分片。:param t: 要使用 partition_spec 注释的输入张量。:type t: Union[torch.Tensor, XLAShardedTensor] :param mesh: 描述逻辑 XLA 设备拓扑和底层设备 ID。:type mesh: Mesh :param partition_spec: 设备网格维度索引的元组或

None。每个索引都是一个 int,如果网格轴已命名,则为 str,或 int 或 str 的元组。这指定了每个输入秩如何分片(索引到 mesh_shape)或复制(None)。当指定元组时,相应的输入张量轴将沿元组中的所有逻辑轴分片。请注意,元组中指定网格轴的顺序将影响最终的分片。

参数
  • 示例 (对于) –

  • 行方向 (我们可以将 8x10 张量 4 路分片) –

  • 列方向。 (并复制) –

  • torch.randn (>> input =) –

  • = (>> partition_spec) –

  • =

  • dynamo_custom_op (布尔值) – 如果设置为 True,则调用 mark_sharding 的 Dynamo 自定义操作变体,使其能够被 Dynamo 识别和跟踪。

示例 ——————————— mesh_shape = (4, 2) num_devices = xr.global_runtime_device_count() device_ids = np.array(range(num_devices)) mesh = Mesh(device_ids, mesh_shape, (‘x’, ‘y’))

# 4 路数据并行 input = torch.randn(8, 32).to(xm.xla_device()) xs.mark_sharding(input, mesh, (0, None))

# 2 路模型并行 linear = nn.Linear(32, 10).to(xm.xla_device()) xs.mark_sharding(linear.weight, mesh, (None, 1))

torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor[source]

清除输入张量的分片注释并返回一个cpu转换后的张量。

torch_xla.distributed.spmd.set_global_mesh(mesh: Mesh)[source]
torch_xla.distributed.spmd.get_global_mesh()[source]
class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)[source]

描述逻辑 XLA 设备拓扑网格和底层资源。

参数
  • device_ids (联合体[np.ndarray, 列表]) – 以自定义顺序排列的设备(ID)的扁平化列表。该列表被重新整形为一个mesh_shape数组,使用 C 样式索引顺序填充元素。

  • mesh_shape (元组[python:int, ...]) – 一个 int 元组,描述设备网格的逻辑拓扑形状,每个元素描述相应轴上的设备数量。

  • axis_names (元组[字符串, ...]) – 要分配给devices参数维度的资源轴名称序列。其长度应与devices的秩匹配。

示例: ——————————— mesh_shape = (4, 2) num_devices = len(xm.get_xla_supported_devices()) device_ids = np.array(range(num_devices)) mesh = Mesh(device_ids, mesh_shape, (‘x’, ‘y’)) mesh.get_logical_mesh() >> array([[0, 1],

[2, 3], [4, 5], [6, 7]])

mesh.shape() >> OrderedDict([(‘x’, 4), (‘y’, 2)])

class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)[source]
创建由 ICI 和 DCN 网络连接的设备的混合设备网格。

逻辑网格的形状应按网络强度递增的顺序排列,例如 [副本、数据、模型],其中 mdl 具有最多的网络通信需求。

参数
  • ici_mesh_shape – 内部连接设备的逻辑网格形状。

  • dcn_mesh_shape – 外部连接设备的逻辑网格形状。

示例

# 此示例假设 v4-8 的 2 个切片。 ici_mesh_shape = (1, 4, 1) # (数据, fsdp, 张量) dcn_mesh_shape = (2, 1, 1)

mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, (‘data’,’fsdp’,’tensor’)) print(mesh.shape()) >> OrderedDict([(‘data’, 2), (‘fsdp’, 4), (‘tensor’, 1)])

class torch_xla.distributed.spmd.ShardingSpec(mesh: torch_xla.distributed.spmd.xla_sharding.Mesh, partition_spec: Tuple[Optional[int]], minibatch: Optional[bool] = False)[source]

实验性

torch_xla.experimental.eager_mode(enable: bool)[source]

配置 torch_xla 的默认执行模式。

在急切模式下,只有使用`torch_xla.compile`编译过的函数才会被跟踪和编译。其他 torch 操作将以急切方式执行。

torch_xla.experimental.compile(func)[source]

使用延迟张量编译函数。

返回采用完全相同输入的优化后的函数。Compile 将在使用延迟张量的跟踪模式下运行目标函数。

调试

torch_xla.debug.metrics.metrics_report()[source]

检索包含完整指标和计数器报告的字符串。

torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[列表] = None, metric_names: Optional[列表] = None)[source]

检索包含完整指标和计数器报告的字符串。

参数
  • counter_names (列表) – 需要打印其数据的计数器名称列表。

  • metric_names (列表) – 需要打印其数据的指标名称列表。

torch_xla.debug.metrics.counter_names()[source]

检索所有当前活动的计数器名称。

torch_xla.debug.metrics.counter_value(name)[source]

返回活动计数器的值。

参数

name (字符串) – 需要检索其值的计数器的名称。

返回

计数器值作为整数。

torch_xla.debug.metrics.metric_names()[source]

检索所有当前活动的指标名称。

torch_xla.debug.metrics.metric_data(name)[source]

返回活动指标的数据。

参数

name (字符串) – 需要检索其数据的指标的名称。

返回

指标数据,它是一个 (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES) 的元组。 TOTAL_SAMPLES 是已发布到指标的样本总数。指标仅保留给定数量的样本(在循环缓冲区中)。ACCUMULATORTOTAL_SAMPLES 上样本的总和。SAMPLES 是 (TIME, VALUE) 元组的列表。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源