快捷方式

PyTorch/XLA API

torch_xla

torch_xla.device(index: Optional[int] = None) device[source]

返回指定 XLA 设备实例。

如果启用 SPMD,则返回一个虚拟设备,该设备封装了此进程可用的所有设备。

参数

index – 要返回的 XLA 设备的索引。对应于 torch_xla.devices() 中的索引。

返回

一个 XLA torch.device 对象。

torch_xla.devices() List[device][source]

返回当前进程中所有可用的设备。

返回

一个包含 XLA torch.devices 对象的列表。

torch_xla.device_count() int[source]

返回当前进程中可寻址的设备数量。

torch_xla.sync(wait: bool = False)[source]

启动所有挂起的图操作。

参数

wait (bool) – 是否阻塞当前进程直到执行完成。

torch_xla.compile(f: Optional[Callable] = None, full_graph: Optional[bool] = False, name: Optional[str] = None, num_different_graphs_allowed: Optional[int] = None)[source]

使用 torch_xla 的 LazyTensor 跟踪模式优化给定的模型/函数。PyTorch/XLA 将使用给定的输入跟踪给定的函数,然后生成图来表示函数内部发生的 PyTorch 操作。此图将由 XLA 编译并在加速器(由张量的设备决定)上执行。对于函数的编译区域,即时模式将被禁用。

参数
  • model (Callable) – 要优化的模块/函数,如果未传入,此函数将作为上下文管理器使用。

  • full_graph (Optional[bool]) – 此编译是否应生成一个单一图。如果设置为 True 并且将生成多个图,torch_xla 将抛出错误并带调试信息退出。

  • name (Optional[name]) – 编译程序的名称。如果未指定,将使用函数 f 的名称。此名称将用于 PT_XLA_DEBUG 消息以及 HLO/IR dump 文件中。

  • num_different_graphs_allowed (Optional[python:int]) – 允许拥有的给定模型/函数的不同跟踪图数量。如果超出此限制,将引发错误。

示例

# usage 1
@torch_xla.compile()
def foo(x):
  return torch.sin(x) + torch.cos(x)

def foo2(x):
  return torch.sin(x) + torch.cos(x)
# usage 2
compiled_foo2 = torch_xla.compile(foo2)

# usage 3
with torch_xla.compile():
  res = foo2(x)
torch_xla.manual_seed(seed, device=None)[source]

为当前 XLA 设备设置生成随机数的种子。

参数
  • seed (python:integer) – 要设置的状态。

  • device (torch.device, optional) – 需要设置 RNG 状态的设备。如果省略,将设置默认设备的种子。

runtime

torch_xla.runtime.device_type() Optional[str][source]

返回当前的 PjRt 设备类型。

如果未配置任何设备,则选择一个默认设备

返回

设备的字符串表示形式。

torch_xla.runtime.local_process_count() int[source]

返回在此主机上运行的进程数量。

torch_xla.runtime.local_device_count() int[source]

返回此主机上的设备总数。

假设每个进程拥有相同数量的可寻址设备。

torch_xla.runtime.addressable_device_count() int[source]

返回此进程可见的设备数量。

torch_xla.runtime.global_device_count() int[source]

返回所有进程/主机上的设备总数。

torch_xla.runtime.global_runtime_device_count() int[source]

返回所有进程/主机上的运行时设备总数,这对于 SPMD 特别有用。

torch_xla.runtime.world_size() int[source]

返回参与作业的进程总数。

torch_xla.runtime.global_ordinal() int[source]

返回此线程在所有进程中的全局序号。

全局序号在 [0, global_device_count) 范围内。全局序号与 TPU worker ID 之间不保证有任何可预测的关系,也不保证在每个主机上是连续的。

torch_xla.runtime.local_ordinal() int[source]

返回此线程在此主机内的本地序号。

本地序号在 [0, local_device_count) 范围内。

torch_xla.runtime.get_master_ip() str[source]

检索运行时的 master worker IP。这会调用后端特定的发现 API。

返回

master worker 的 IP 地址,以字符串形式表示。

torch_xla.runtime.use_spmd(auto: Optional[bool] = False)[source]

用于启用 SPMD 模式的 API。这是启用 SPMD 的推荐方法。

如果某些张量已经在非 SPMD 设备上初始化,这将强制进入 SPMD 模式。这意味着这些张量将在设备间进行复制。

参数

auto (*bool*) – 是否启用自动分片。阅读 https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding 获取更多详情。

torch_xla.runtime.is_spmd()[source]

返回是否已为执行设置 SPMD。

torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)[source]

初始化持久化编译缓存。此 API 必须在执行任何计算之前调用。

参数
  • path (*str*) – 存储持久化缓存的路径。

  • readonly (*bool*) – 此 worker 是否应具有缓存的写访问权限。

xla_model

torch_xla.core.xla_model.xla_device(n: Optional[int] = None, devkind: Optional[str] = None) device[source]

返回指定 XLA 设备实例。

参数
  • n (*python:int*, *optional*) – 要返回的特定实例(序号)。如果指定,将返回特定的 XLA 设备实例。否则将返回 devkind 的第一个设备。

  • devkind (*string...*, *optional*) – 如果指定,则为设备类型,例如 TPUCUDACPU 或自定义 PJRT 设备。已弃用。

返回

具有请求实例的 torch.device 对象。

torch_xla.core.xla_model.xla_device_hw(device: Union[str, device]) str[source]

返回给定设备的硬件类型。

参数

device (*string* 或 *torch.device*) – 将映射到实际设备的 xla 设备。

返回

给定设备的硬件类型的字符串表示形式。

torch_xla.core.xla_model.is_master_ordinal(local: bool = True) bool[source]

检查当前进程是否是 master 序号 (0)。

参数

local (*bool*) – 是否检查本地或全局 master 序号。在多主机复制的情况下,只有一个全局 master 序号(主机 0,设备 0),而有 NUM_HOSTS 个本地 master 序号。默认值:True

返回

一个布尔值,指示当前进程是否是 master 序号。

torch_xla.core.xla_model.all_reduce(reduce_type: str, inputs: Union[Tensor, List[Tensor]], scale: float = 1.0, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Union[Tensor, List[Tensor]][source]

对输入张量执行原地规约操作。

参数
  • reduce_type (*string*) – 以下之一:xm.REDUCE_SUMxm.REDUCE_MULxm.REDUCE_ANDxm.REDUCE_ORxm.REDUCE_MINxm.REDUCE_MAX

  • inputs – 一个单独的 torch.Tensor 或一个 torch.Tensor 列表,用于执行 all reduce 操作。

  • scale (*python:float*) – 在规约后应用的默认缩放值。默认值:1.0

  • groups (*list*, *optional*) –

    一个列表的列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。

  • pin_layout (*bool*, *optional*) – 是否为此通信操作固定布局。当参与通信的每个进程具有稍微不同的程序时,固定布局可以防止潜在的数据损坏,但可能会导致一些 xla 编译失败。当看到类似于“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。

返回

如果传入单个 torch.Tensor,返回值是一个 torch.Tensor,其中包含规约后的值(跨副本)。如果传入列表/元组,此函数将对输入张量执行原地 all-reduce 操作,并返回列表/元组本身。

torch_xla.core.xla_model.all_gather(value: Tensor, dim: int = 0, groups: Optional[List[List[int]]] = None, output: Optional[Tensor] = None, pin_layout: bool = True) Tensor[source]

沿给定维度执行 all-gather 操作。

参数
  • value (torch.Tensor) – 输入张量。

  • dim (python:int) – 聚集维度。默认值: 0

  • groups (*list*, *optional*) –

    一个列表的列表,表示 all_gather() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。

  • output (torch.Tensor) – 可选的输出张量。

  • pin_layout (*bool*, *optional*) – 是否为此通信操作固定布局。当参与通信的每个进程具有稍微不同的程序时,固定布局可以防止潜在的数据损坏,但可能会导致一些 xla 编译失败。当看到类似于“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。

返回

一个张量,其 dim 维度包含所有参与副本的值。

torch_xla.core.xla_model.all_to_all(value: Tensor, split_dimension: int, concat_dimension: int, split_count: int, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Tensor[source]

对输入张量执行 XLA AllToAll() 操作。

参阅:https://tensorflowcn.cn/xla/operation_semantics#alltoall

参数
  • value (torch.Tensor) – 输入张量。

  • split_dimension (python:int) – 进行分割的维度。

  • concat_dimension (python:int) – 进行拼接的维度。

  • split_count (python:int) – 分割计数。

  • groups (*list*, *optional*) –

    一个列表的列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。

  • pin_layout (*bool*, *optional*) – 是否为此通信操作固定布局。当参与通信的每个进程具有稍微不同的程序时,固定布局可以防止潜在的数据损坏,但可能会导致一些 xla 编译失败。当看到类似于“HloModule has a mix of layout constrained”的错误消息时,请取消固定布局。

返回

all_to_all() 操作的结果 torch.Tensor

torch_xla.core.xla_model.add_step_closure(closure: Callable[[...], Any], args: Tuple[Any, ...] = (), run_async: bool = False)[source]

将一个闭包添加到步骤结束时运行的闭包列表。

在模型训练过程中,经常需要在中间张量的内容被检查后打印/报告信息(打印到控制台,发布到 tensorboard 等)。在模型代码的不同点检查不同张量的内容需要多次执行,通常会导致性能问题。添加步骤闭包将确保它在屏障之后运行,此时所有存活的张量都已具体化到设备数据。存活的张量将包括闭包参数捕获的张量。因此,使用 add_step_closure() 将确保只执行一次,即使有多个闭包排队,需要检查多个张量。步骤闭包将按排队顺序依次运行。请注意,即使使用此 API 可优化执行,仍建议每 N 个步骤限制一次打印/报告事件的频率。

参数
  • closure (callable) – 要调用的函数。

  • args (tuple) – 要传递给闭包的参数。

  • run_async – 如果为 True,则异步运行闭包。

torch_xla.core.xla_model.wait_device_ops(devices: List[str] = [])[source]

等待给定设备上的所有异步操作完成。

参数

devices (string..., optional) – 需要等待其异步操作的设备。如果为空,则等待所有本地设备。

torch_xla.core.xla_model.optimizer_step(optimizer: Optimizer, barrier: bool = False, optimizer_args: Dict = {}, groups: Optional[List[List[int]]] = None, pin_layout: bool = True)[source]

运行提供的优化器步骤并在所有设备上同步梯度。

参数
  • optimizer (torch.Optimizer) – 需要调用其 step() 函数的 torch.Optimizer 实例。step() 函数将使用 optimizer_args 命名参数调用。

  • barrier (bool, optional) – 是否在此 API 中发出 XLA 张量屏障。如果使用 PyTorch XLA ParallelLoaderDataParallel 支持,则不是必需的,因为屏障将由 XLA 数据加载器迭代器的 next() 调用发出。默认值: False

  • optimizer_args (dict, optional) – optimizer.step() 调用的命名参数字典。

  • groups (*list*, *optional*) –

    一个列表的列表,表示 all_reduce() 操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含 [0, 1, 2, 3] 副本,另一个包含 [4, 5, 6, 7] 副本。如果为 None,则只有一个包含所有副本的组。

  • pin_layout (bool, optional) – 在规约梯度时是否固定布局。详细信息请参阅 xm.all_reduce

返回

optimizer.step() 调用返回的值相同。

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.optimizer_step(self.optimizer)
torch_xla.core.xla_model.save(data: Any, file_or_path: Union[str, TextIO], master_only: bool = True, global_master: bool = False)[source]

将输入数据保存到文件。

保存的数据在保存之前会被传输到 PyTorch CPU 设备,因此后续的 torch.load() 将加载 CPU 数据。处理视图时必须小心。建议在张量加载并移动到其目标设备后重新创建视图,而不是保存视图。

参数
  • data – 要保存的输入数据。可以是 Python 对象的任何嵌套组合(列表、元组、集合、字典等)。

  • file_or_path – 数据保存操作的目标。可以是文件路径或 Python 文件对象。如果 master_onlyFalse,则路径或文件对象必须指向不同的目标,否则同一主机的所有写入将相互覆盖。

  • master_only (bool, optional) – 是否仅主设备保存数据。如果为 False,则 file_or_path 参数对于参与复制的每个序号都应是不同的文件或路径,否则同一主机上的所有副本将写入同一位置。默认值: True

  • global_master (bool, optional) – 当 master_onlyTrue 时,此标志控制是每个主机的 master(如果 global_masterFalse)保存内容,还是仅全局 master(序号 0)保存。默认值: False

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.wait_device_ops() # wait for all pending operations to finish.
>>> xm.save(obj_to_save, path_to_save)
>>> xm.rendezvous('torch_xla.core.xla_model.save') # multi process context only
torch_xla.core.xla_model.rendezvous(tag: str, payload: bytes = b'', replicas: List[int] = []) List[bytes][source]

等待所有网格客户端到达指定的会合点。

注意:PJRT 不支持 XRT 网格服务器,因此这实际上是 xla_rendezvous 的别名。

参数
  • tag (string) – 要加入的会合点的名称。

  • payload (bytes, optional) – 要发送到会合点的载荷。

  • replicas (list, python:int) – 参与会合的副本序号。空列表表示网格中的所有副本。默认值: []

返回

所有其他核心交换的载荷列表,核心序号 i 的载荷位于返回列表的第 i 个位置。

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.rendezvous('example')
torch_xla.core.xla_model.mesh_reduce(tag: str, data, reduce_fn: Callable[[...], Any]) Union[Any, ToXlaTensorArena][source]

执行图外客户端网格规约。

参数
  • tag (string) – 要加入的会合点的名称。

  • data – 要规约的数据。reduce_fn 可调用对象将接收一个列表,其中包含来自所有网格客户端进程(每个核心一个)的相同数据的副本。

  • reduce_fn (callable) – 一个函数,接收一个类 data 对象的列表并返回规约结果。

返回

规约后的值。

示例

>>> import torch_xla.core.xla_model as xm
>>> import numpy as np
>>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
torch_xla.core.xla_model.set_rng_state(seed: int, device: Optional[str] = None)[source]

设置随机数生成器状态。

参数
  • seed (python:integer) – 要设置的状态。

  • device (string, optional) – 需要设置 RNG 状态的设备。如果缺失,将设置默认设备种子。

torch_xla.core.xla_model.get_rng_state(device: Optional[str] = None) int[source]

获取当前运行的随机数生成器状态。

参数

device (string, optional) – 需要检索其 RNG 状态的设备。如果缺失,将设置默认设备种子。

返回

RNG 状态,以整数形式。

torch_xla.core.xla_model.get_memory_info(device: Optional[device] = None) MemoryInfo[source]

检索设备内存使用情况。

参数
  • device – Optional[torch.device] 需要其内存信息的设备。

  • device. (If not passed will use the default) – 如果未传递,将使用默认设备。

返回

带有给定设备内存使用情况的 MemoryInfo 字典。

示例

>>> xm.get_memory_info()
{'bytes_used': 290816, 'bytes_limit': 34088157184, 'peak_bytes_used': 500816}
torch_xla.core.xla_model.get_stablehlo(tensors: Optional[List[Tensor]] = None) str[source]

获取计算图的 StableHLO 字符串格式。

如果 tensors 不为空,将转储以 tensors 为输出的图。如果 tensors 为空,将转储整个计算图。

对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不直接。建议使用空的 tensors

要在 StableHLO 中启用源代码行信息,请设置环境变量 XLA_HLO_DEBUG=1。

参数

tensors (list[torch.Tensor], optional) – 代表 StableHLO 图的输出/根张量。

返回

StableHLO Module 的字符串格式。

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors: Optional[Tensor] = None) bytes[source]

获取计算图的 StableHLO 字节码格式。

如果 tensors 不为空,将转储以 tensors 为输出的图。如果 tensors 为空,将转储整个计算图。

对于推理图,建议将模型输出传递给 tensors。对于训练图,识别“输出”并不直接。建议使用空的 tensors

参数

tensors (list[torch.Tensor], optional) – 代表 StableHLO 图的输出/根张量。

返回

StableHLO Module 的字节码格式。

分布式

class torch_xla.distributed.parallel_loader.MpDeviceLoader(loader, device, **kwargs)[source]

使用后台数据上传包装现有 PyTorch DataLoader。

此类仅应与多进程数据并行一起使用。它将使用 ParallelLoader 包装传入的 dataloader 并返回当前设备的 per_device_loader。

参数
  • loader (torch.utils.data.DataLoader) – 要包装的 PyTorch DataLoader。

  • device (torch.device…) – 必须将数据发送到的设备。

  • kwargsParallelLoader 构造函数的命名参数。

示例

>>> device = torch_xla.device()
>>> train_device_loader = MpDeviceLoader(train_loader, device)
torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]

启用基于多进程的复制。

参数
  • fn (callable) – 为参与复制的每个设备调用的函数。该函数将被调用,第一个参数是复制中进程的全局索引,后跟 args 中传递的参数。

  • args (tuple) – fn 的参数。默认值: 空元组

  • nprocs (python:int) – 复制的进程/设备数量。目前,如果指定,可以是 1 或 None(这将自动转换为最大设备数量)。其他数字将导致 ValueError。

  • join (bool) – 调用是否应阻塞等待已启动进程的完成。默认值: True

  • daemon (bool) – 启动的进程是否应设置 daemon 标志(参见 Python 多进程 API)。默认值: False

  • start_method (string) – Python multiprocessing 进程创建方法。默认值: spawn

返回

torch.multiprocessing.spawn API 返回的对象相同。如果 nprocs 为 1,则直接调用 fn 函数,API 将返回 None。

spmd

torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Optional[Union[Tuple, int, str]], ...]) XLAShardedTensor[source]

使用 XLA 分区规范注释提供的张量。在内部,它将相应的 XLATensor 注释为分片,供 XLA SpmdPartitioner 阶段使用。

参数
  • t (Union[torch.Tensor, XLAShardedTensor]) – 要使用 partition_spec 注释的输入张量。

  • mesh (Mesh) – 描述逻辑 XLA 设备拓扑和底层设备 ID。

  • partition_spec (Tuple[Tuple, python:int, str, None]) – 设备网格维度索引或 None 的元组。每个索引是一个 int、如果网格轴被命名则是 str,或者 int 或 str 的元组。这指定了每个输入秩如何被分片(网格形状的索引)或复制(None)。指定元组时,对应的输入张量轴将沿着元组中所有逻辑轴进行分片。请注意,元组中指定网格轴的顺序将影响结果分片。

  • dynamo_custom_op (bool) – 如果设置为 True,它将调用 mark_sharding 的 dynamo 自定义操作变体,使其可被 dynamo 识别和追踪。

示例

>>> import torch_xla.runtime as xr
>>> import torch_xla.distributed.spmd as xs
>>> mesh_shape = (4, 2)
>>> num_devices = xr.global_runtime_device_count()
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> input = torch.randn(8, 32).to(xm.xla_device())
>>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel
>>> linear = nn.Linear(32, 10).to(xm.xla_device())
>>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor[source]

清除输入张量的分片注释并返回转换为 cpu 的张量。这是一个就地操作,但也会返回同一个 torch.Tensor。

参数

t (Union[torch.Tensor, XLAShardedTensor]) – 我们想要清除分片的张量。

返回

未分片的张量。

Return type

t (torch.Tensor)

示例

>>> import torch_xla.distributed.spmd as xs
>>> torch_xla.runtime.use_spmd()
>>> t1 = torch.randn(8,8).to(torch_xla.device())
>>> mesh = xs.get_1d_mesh()
>>> xs.mark_sharding(t1, mesh, (0, None))
>>> xs.clear_sharding(t1)
torch_xla.distributed.spmd.set_global_mesh(mesh: Mesh)[source]

设置可用于当前进程的全局网格。

参数

mesh – (Mesh) 将成为全局网格的 Mesh 对象。

示例

>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> xs.set_global_mesh(mesh)
torch_xla.distributed.spmd.get_global_mesh() Optional[Mesh][source]

获取当前进程的全局网格。

返回

(Optional[Mesh]) 如果已设置全局网格,则为 Mesh 对象,否则返回 None。

Return type

mesh

示例

>>> import torch_xla.distributed.spmd as xs
>>> xs.get_global_mesh()
torch_xla.distributed.spmd.get_1d_mesh(axis_name: Optional[str] = None) Mesh[source]

一个辅助函数,返回所有设备位于同一维度的网格。

参数

axis_name – (Optional[str]) 可选字符串,表示网格的轴名称

返回

Mesh 对象

Return type

Mesh

示例

>>> # This example is assuming 1 TPU v4-8
>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> print(mesh.mesh_shape)
(4,)
>>> print(mesh.axis_names)
('data',)
class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)[source]

描述逻辑 XLA 设备拓扑网格和底层资源。

参数
  • device_ids (Union[np.ndarray, List]) – 按自定义顺序排列的设备(ID)展平列表。该列表将被重塑为 mesh_shape 数组,并使用类似 C 语言的索引顺序填充元素。

  • mesh_shape (Tuple[python:int, ...]) – 一个整数元组,描述设备网格的逻辑拓扑形状,且每个元素描述了相应轴上的设备数量。

  • axis_names (Tuple[str, ...]) – 资源轴名称的序列,将分配给 devices 参数的维度。其长度应与 devices 的秩匹配。

示例

>>> mesh_shape = (4, 2)
>>> num_devices = len(xm.get_xla_supported_devices())
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> mesh.get_logical_mesh()
>>> array([[0, 1],
          [2, 3],
          [4, 5],
          [6, 7]])
>>> mesh.shape()
OrderedDict([('x', 4), ('y', 2)])
class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)[source]
创建一个由通过 ICI 和 DCN 网络连接的设备组成的混合设备网格。

逻辑网格的形状应按网络强度递增的顺序排列,例如 [replica, data, model],其中 model 具有最高的网络通信需求。

参数
  • ici_mesh_shape – 用于内部连接设备的逻辑网格形状。

  • dcn_mesh_shape – 用于外部连接设备的逻辑网格形状。

示例

>>> # This example is assuming 2 slices of v4-8.
>>> ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
>>> dcn_mesh_shape = (2, 1, 1)
>>> mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
>>> print(mesh.shape())
>>> >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

experimental

torch_xla.experimental.eager_mode(enable: bool)[source]

配置 torch_xla 的默认执行模式。

在 eager 模式下,只有被 `torch_xla.compile` 编译的函数才会被跟踪和编译。其他 torch 操作将以 eager 模式执行。

debug

torch_xla.debug.metrics.metrics_report()[source]

检索包含完整的指标和计数器报告的字符串。

torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[list] = None, metric_names: Optional[list] = None)[source]

检索包含完整的指标和计数器报告的字符串。

参数
  • counter_names (list) – 需要打印其数据的计数器名称列表。

  • metric_names (list) – 需要打印其数据的指标名称列表。

torch_xla.debug.metrics.counter_names()[source]

检索所有当前活跃的计数器名称。

torch_xla.debug.metrics.counter_value(name)[source]

返回活跃计数器的值。

参数

name (string) – 需要检索其值的计数器的名称。

返回

计数器值为整数。

torch_xla.debug.metrics.metric_names()[source]

检索所有当前活跃的指标名称。

torch_xla.debug.metrics.metric_data(name)[source]

返回活跃指标的数据。

参数

name (string) – 需要检索其数据的指标名称。

返回

指标数据,它是一个由 (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES) 组成的元组。TOTAL_SAMPLES 是已发布到该指标的总样本数。指标仅保留给定数量的样本(在循环缓冲区中)。ACCUMULATORTOTAL_SAMPLES 中样本的总和。SAMPLES 是一个 (TIME, VALUE) 元组列表。

文档

访问完整的 PyTorch 开发者文档

查看文档

教程

获取适用于初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源