• 文档 >
  • torch.distributed.tensor
快捷方式

torch.distributed.tensor

注意

torch.distributed.tensor 目前处于 alpha 状态并正在开发中,我们致力于为文档中列出的大多数 API 提供向后兼容性,但如有必要,可能会进行 API 更改。

PyTorch DTensor (分布式张量)

PyTorch DTensor 提供简单而灵活的张量分片原语,可以透明地处理分布式逻辑,包括分片存储、算子计算以及跨设备/主机的集合通信。DTensor 可用于构建不同的并行解决方案,并在处理多维分片时支持分片 state_dict 表示。

请参阅基于 DTensor 构建的 PyTorch 原生并行解决方案的示例

DTensor 遵循 SPMD(单程序,多数据)编程模型,使用户能够编写分布式程序,就像它是具有相同收敛属性的单设备程序一样。它通过指定 DeviceMeshPlacement 提供统一的张量分片布局(DTensor 布局)

  • DeviceMesh 表示设备拓扑和集群的通信器,使用 n 维数组。

  • Placement 描述 DeviceMesh 上逻辑张量的分片布局。DTensor 支持三种类型的 Placement:ShardReplicatePartial

DTensor 类 API

DTensortorch.Tensor 的子类。这意味着一旦创建了 DTensor,它就可以以非常类似于 torch.Tensor 的方式使用,包括运行不同类型的 PyTorch 算子,就像在单个设备中运行它们一样,从而为 PyTorch 算子实现适当的分布式计算。

除了现有的 torch.Tensor 方法外,它还提供了一组额外的方法来与 torch.Tensor 交互,将 DTensor 布局 redistribute 到新的 DTensor,获取所有设备上的完整张量内容等等。

class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)

DTensor (分布式张量) 是 torch.Tensor 的子类,它提供类似单设备的抽象,用于使用多设备 torch.Tensor 进行编程。它通过 DeviceMesh 和以下类型的 Placement 描述分布式张量分片布局(DTensor 布局)

  • Shard:张量在 DeviceMesh 维度的设备上沿张量维度 dim 分片

  • Replicate:张量在 DeviceMesh 维度的设备上复制

  • Partial:张量在 DeviceMesh 维度的设备上等待规约

当调用 PyTorch 算子时,DTensor 会覆盖 PyTorch 算子以执行分片计算,并在必要时发出通信。伴随算子计算,DTensor 将正确转换或传播 placement(DTensor 布局)(基于算子语义本身)并生成新的 DTensor 输出。

为了确保在调用 PyTorch 算子时 DTensor 分片计算的数值正确性,DTensor 要求算子的每个张量参数都必须是 DTensor。

注意

直接使用张量子类构造函数在这里不是创建 DTensor 的推荐方法(即,它不能正确处理 autograd,因此不是公共 API)。请参阅 create_dtensor 部分,了解如何创建 DTensor

返回类型

DTensor

static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source][source]

根据指定的 device_meshplacements,从每个 rank 上的本地 torch.Tensor 创建 DTensor

参数
  • local_tensor (torch.Tensor) – 每个 rank 上的本地 torch.Tensor。

  • device_mesh (DeviceMesh, 可选) – 用于放置张量的 DeviceMesh,如果未指定,则必须在 DeviceMesh 上下文管理器下调用,默认值:None

  • placements (List[Placement], 可选) – 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的 placement,必须具有与 device_mesh.ndim 相同数量的元素。

关键字参数
  • run_check (bool, 可选) – 以额外的通信为代价,跨 rank 执行健全性检查,以检查每个本地张量的元信息,确保正确性。如果在 placements 中有 Replicate,则 DeviceMesh 维度第一个 rank 上的数据将广播到其他 rank。默认值:False

  • shape (torch.Size, 可选) – 一个整数列表,指定构建在 local_tensor 之上的 DTensor 的大小。请注意,如果 local_tensor 的形状在不同 rank 之间不同,则需要提供此参数。如果未提供,则将计算 shape,假设给定的分布式张量在 rank 之间均匀分片。默认值:None

  • stride (tuple, 可选) – 一个整数列表,指定 DTensor 的步幅。如果未提供,则将计算 stride,假设给定的分布式张量在 rank 之间均匀分片。默认值:None

返回

一个 DTensor 对象

返回类型

DTensor

注意

run_check=False 时,用户有责任确保传入的本地张量在不同 rank 之间是正确的(即,张量对于 Shard(dim) placement 是分片的,或者对于 Replicate() placement 是复制的)。否则,创建的 DTensor 的行为是未定义的。

注意

from_local 是可微分的,创建的 DTensor 对象的 requires_grad 将取决于 local_tensor 是否需要梯度。

full_tensor(*, grad_placements=None)[source][source]

返回此 DTensor 的完整张量。它将执行必要的集合操作,以收集来自其 DeviceMesh 中其他 rank 的本地张量,并将它们连接在一起。它是以下代码的语法糖

dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()

关键字参数

grad_placements (List[Placement], 可选) – placements 描述从此函数返回的完整张量的任何梯度布局的未来布局。full_tensor 将 DTensor 转换为完整的 torch.Tensor,并且返回的 torch.tensor 可能不会在代码的后面用作原始复制的 DTensor 布局。此参数是用户可以给 autograd 的提示,以防返回张量的梯度布局与原始复制的 DTensor 布局不匹配。如果未指定,我们将假设完整张量的梯度布局是复制的。

返回

一个 torch.Tensor 对象,表示此 DTensor 的完整张量。

返回类型

张量

注意

full_tensor 是可微分的。

redistribute(device_mesh=None, placements=None, *, async_op=False)[source][source]

redistribute 执行必要的集合操作,将当前 DTensor 从其当前 placement 重新分布到新的 placement,或者从其当前 DeviceMesh 重新分布到新的 DeviceMesh。即,我们可以通过为 DeviceMesh 的每个维度指定 Replicate placement,将分片 DTensor 转换为复制 DTensor。

当在一个设备网格维度上从当前 placement 重新分布到新的 placement 时,我们将执行以下操作,包括通信集合或本地操作

  1. Shard(dim) -> Replicate(): all_gather

  2. Shard(src_dim) -> Shard(dst_dim): all_to_all

  3. Replicate() -> Shard(dim): 本地分块 (即 torch.chunk)

  4. Partial() -> Replicate(): all_reduce

  5. Partial() -> Shard(dim): reduce_scatter

redistribute 将正确地计算出在一维或多维 DeviceMesh 上创建的 DTensor 所需的重新分布步骤。

参数
  • device_mesh (DeviceMesh, 可选) – 用于放置 DTensor 的 DeviceMesh。如果未指定,它将使用当前 DTensor 的 DeviceMesh。默认值:None

  • placements (List[Placement], 可选) – 描述如何将 DTensor 放置到 DeviceMesh 中的新 placement,必须具有与 device_mesh.ndim 相同数量的元素。默认值:在所有网格维度上复制

关键字参数

async_op (bool, 可选) – 是否异步执行 DTensor 重新分布操作。默认值:False

返回

一个 DTensor 对象

返回类型

DTensor

注意

redistribute 是可微分的,这意味着用户无需担心重新分布操作的反向公式。

注意

redistribute 目前仅支持在同一 DeviceMesh 上重新分布 DTensor,如果您需要在不同的 DeviceMesh 上重新分布 DTensor,请提交 issue。

to_local(*, grad_placements=None)[source][source]

获取此 DTensor 在其当前 rank 上的本地张量。对于分片,它返回逻辑张量视图的本地分片;对于复制,它返回其当前 rank 上的副本。

关键字参数

grad_placements (List[Placement], 可选) – placements 描述从此函数返回的张量的任何梯度布局的未来布局。to_local 将 DTensor 转换为本地张量,并且返回的本地张量可能不会在代码的后面用作原始 DTensor 布局。此参数是用户可以给 autograd 的提示,以防返回张量的梯度布局与原始 DTensor 布局不匹配。如果未指定,我们将假设梯度布局与原始 DTensor 保持不变,并将其用于梯度计算。

返回

一个 torch.TensorAsyncCollectiveTensor 对象。它表示其当前 rank 上的本地张量。当返回 AsyncCollectiveTensor 对象时,表示本地张量尚未准备好(即通信未完成)。在这种情况下,用户需要调用 wait 以等待本地张量准备就绪。

返回类型

张量

注意

to_local 是可微分的,返回的本地张量的 requires_grad 将取决于 DTensor 是否需要梯度。

property device_mesh: DeviceMesh

与此 DTensor 对象关联的 DeviceMesh 属性。

注意

device_mesh 是只读属性,无法设置。

property placements: Tuple[Placement, ...]

此 DTensor 的 placement 属性,描述此 DTensor 在其 DeviceMesh 上的布局。

注意

placements 是只读属性,无法设置。

DeviceMesh 作为分布式通信器

DeviceMesh 是从 DTensor 构建的抽象,用于描述集群的设备拓扑并表示多维通信器(基于 ProcessGroup)。要查看如何创建/使用 DeviceMesh 的详细信息,请参阅 DeviceMesh 食谱

DTensor Placement 类型

DTensor 支持每个 DeviceMesh 维度上的以下类型的 Placement

class torch.distributed.tensor.placement_types.Shard(dim)[source][source]

Shard(dim) placement 描述在张量维度 dim 上,跨相应的 DeviceMesh 维度进行 DTensor 分片,其中 DeviceMesh 维度上的每个 rank 仅保存全局张量的分片/片段。Shard(dim) placement 遵循 torch.chunk(dim) 语义,当张量维度不能在 DeviceMesh 维度上均匀分割时,DeviceMesh 维度上的最后几个分片可能为空。Shard placement 可以被所有 DTensor API 使用(即 distribute_tensor、from_local 等)

参数

dim (int) – 张量维度,描述 DTensor 在其相应的 DeviceMesh 维度上分片。

警告

在张量维度上进行分片,其中张量维度大小不能在 DeviceMesh 维度上均匀分割,目前是实验性的,可能会发生变化。

dim: int
class torch.distributed.tensor.placement_types.Replicate[source][source]

Replicate() placement 描述在相应的 DeviceMesh 维度上进行 DTensor 复制,其中 DeviceMesh 维度上的每个 rank 都保存全局张量的副本。Replicate placement 可以被所有 DTensor API 使用(即 distribute_tensor, DTensor.from_local 等)

class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source][source]

Partial(reduce_op) placement 描述了 DTensor 在指定的 DeviceMesh 维度上处于待归约状态,其中 DeviceMesh 维度上的每个 rank 都持有全局 Tensor 的部分值。用户可以使用 redistributePartial DTensor 重分布到指定 DeviceMesh 维度上的 ReplicateShard(dim) placement,这将触发底层必要的通信操作(例如 allreducereduce_scatter)。

参数

reduce_op (str, optional) – 用于 partial DTensor 生成 Replicated/Sharded DTensor 的归约操作。仅支持逐元素归约操作,包括:“sum”、“avg”、“product”、“max”、“min”,默认值:“sum”。

注意

Partial placement 可以作为 DTensor 算子的结果生成,并且只能被 DTensor.from_local API 使用。

reduce_op: str = 'sum'
class torch.distributed.tensor.placement_types.Placement[source][source]

Placement 类型的基类,描述了 DTensor 如何放置在 DeviceMesh 上。PlacementDeviceMesh 一起可以描述 DTensor 的布局(Layout)。它是三种主要的 DTensor Placement 类型:ShardReplicatePartial 的基类。

此类不应直接使用,主要用作类型存根(typing stub)。

is_partial()[source][source]
返回类型

bool

is_replicate()[source][source]
返回类型

bool

is_shard(dim=None)[source][source]
返回类型

bool

创建 DTensor 的不同方法

有三种方法可以构建 DTensor
  • distribute_tensor() 从每个 rank 上的逻辑或“全局” torch.Tensor 创建 DTensor。这可以用于对叶子 torch.Tensor s(即模型参数/缓冲区和输入)进行分片。

  • DTensor.from_local() 从每个 rank 上的本地 torch.Tensor 创建 DTensor,可用于从非叶子 torch.Tensor s(即前向/后向传播期间的中间激活张量)创建 DTensor

  • DTensor 提供了专用的张量工厂函数(例如 empty(), ones(), randn() 等)以允许通过直接指定 DeviceMeshPlacement 来创建不同的 DTensor。与 distribute_tensor() 相比,这可以直接在设备上物化分片内存,而不是在初始化逻辑张量内存后执行分片。

从逻辑 torch.Tensor 创建 DTensor

torch.distributed 中的 SPMD(单程序,多数据)编程模型启动多个进程(即通过 torchrun)来执行相同的程序,这意味着程序内部的模型将首先在不同的进程上初始化(即模型可能在 CPU 或 meta 设备上初始化,或者如果内存足够,则直接在 GPU 上初始化)。

DTensor 提供了一个 distribute_tensor() API,可以将模型权重或张量分片为 DTensor s,它将从每个进程上的“逻辑”张量创建一个 DTensor。这将使创建的 DTensor s 能够符合单设备语义,这对于数值正确性至关重要。

torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None)[source]

根据指定的 placements,将叶子 torch.Tensor (即 nn.Parameter/buffers)分布到 device_meshdevice_meshplacements 的 rank 必须相同。要分布的 tensor 是逻辑或“全局”张量,API 将使用 DeviceMesh 维度中第一个 rank 的 tensor 作为真值来源,以保留单设备语义。 如果您想在 Autograd 计算的中间构建 DTensor,请使用 DTensor.from_local() 代替。

参数
  • tensor (torch.Tensor) – 要分布的 torch.Tensor。请注意,如果您想在维度上分片张量,而该维度不能被 mesh 维度中的设备数量均匀整除,我们使用 torch.chunk 语义来分片张量并分散分片。不均匀分片行为是实验性的,可能会发生变化。

  • device_mesh (DeviceMesh, optional) – 用于分布张量的 DeviceMesh,如果未指定,则必须在 DeviceMesh 上下文管理器下调用,默认值:None

  • placements (List[Placement], optional) – 描述如何在 DeviceMesh 上放置张量的 placements,必须具有与 device_mesh.ndim 相同数量的元素。如果未指定,默认情况下,我们将从 device_mesh 每个维度的第一个 rank 复制张量到整个 device_mesh

返回

一个 DTensorXLAShardedTensor 对象。

返回类型

DTensor

注意

当使用 xla device_type 初始化 DeviceMesh 时,distribute_tensor 返回 XLAShardedTensor。有关更多详细信息,请参阅 此 issue。XLA 集成是实验性的,可能会发生变化。

除了 distribute_tensor() 之外,DTensor 还提供了 distribute_module() API,以便更轻松地在 nn.Module 级别进行分片。

torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)[source]

此函数公开了三个函数来控制模块的参数/输入/输出

1. 通过指定 partition_fn 在运行时执行之前对模块执行分片(即允许用户根据指定的 partition_fn 将 Module 参数转换为 DTensor 参数)。 2. 通过指定 input_fnoutput_fn 来控制模块在运行时执行期间的输入或输出。(即,将输入转换为 DTensor,将输出转换回 torch.Tensor

参数
  • module (nn.Module) – 要分区的用户模块。

  • device_mesh (DeviceMesh) – 用于放置模块的设备网格。

  • partition_fn (Callable) – 用于分区参数的函数(即,跨 device_mesh 分片某些参数)。如果未指定 partition_fn,则默认情况下,我们将跨网格复制 module 的所有模块参数。

  • input_fn (Callable) – 指定输入分布,即可以控制模块的输入如何分片。input_fn 将安装为模块 forward_pre_hook(前向预hook)。

  • output_fn (Callable) – 指定输出分布,即可以控制输出如何分片,或将其转换回 torch.Tensor。output_fn 将安装为模块 forward_hook(后向hook)。

返回

一个模块,其中包含的参数/缓冲区都是 DTensor s。

返回类型

模块

注意

当使用 xla device_type 初始化 DeviceMesh 时,distribute_module 返回带有 PyTorch/XLA SPMD 注释参数的 nn.Module。有关更多详细信息,请参阅 此 issue。XLA 集成是实验性的,可能会发生变化。

DTensor 工厂函数

DTensor 还提供了专用的张量工厂函数,允许直接使用类似于 torch.Tensor 的工厂函数 API(即 torch.ones、torch.empty 等)创建 DTensor,同时额外指定要创建的 DTensorDeviceMeshPlacement

torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

返回一个填充标量值 0 的 DTensor

参数

size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))

关键字参数
  • requires_grad (bool, optional) – 如果 autograd 应该记录返回的 DTensor 上的操作。默认值:False

  • dtype (torch.dtype, optional) – 返回的 DTensor 的所需数据类型。默认值:如果为 None,则使用全局默认值(请参阅 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided

  • device_meshDeviceMesh 类型,包含 rank 的 mesh 信息

  • placementsPlacement 类型的序列:Shard, Replicate

返回

每个 rank 上的一个 DTensor 对象

返回类型

DTensor

torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

返回一个填充标量值 1 的 DTensor,其形状由可变参数 size 定义。

参数

size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

关键字参数
  • dtype (torch.dtype, optional) – 返回的 DTensor 的所需数据类型。默认值:如果为 None,则使用全局默认值(请参阅 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 应该记录返回的 DTensor 上的操作。默认值:False

  • device_meshDeviceMesh 类型,包含 rank 的 mesh 信息

  • placementsPlacement 类型的序列:Shard, Replicate

返回

每个 rank 上的一个 DTensor 对象

返回类型

DTensor

torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

返回一个填充了未初始化数据的 DTensorDTensor 的形状由可变参数 size 定义。

参数

size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))

关键字参数
  • dtype (torch.dtype, optional) – 返回的 DTensor 的所需数据类型。默认值:如果为 None,则使用全局默认值(请参阅 torch.set_default_dtype())。 layout (torch.layout, optional): 返回的 DTensor 的所需布局。默认值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 应该记录返回的 DTensor 上的操作。默认值:False

  • device_meshDeviceMesh 类型,包含 rank 的 mesh 信息

  • placementsPlacement 类型的序列:Shard, Replicate

返回

每个 rank 上的一个 DTensor 对象

返回类型

DTensor

torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

返回一个根据 device_meshplacements 填充了 fill_valueDTensor,其形状由参数 size 定义。

参数
  • size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

  • fill_value (Scalar) – 用于填充输出张量的值。

关键字参数
  • dtype (torch.dtype, optional) – 返回的 DTensor 的所需数据类型。默认值:如果为 None,则使用全局默认值(请参阅 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 应该记录返回的 DTensor 上的操作。默认值:False

  • device_meshDeviceMesh 类型,包含 rank 的 mesh 信息。

  • placementsPlacement 类型的序列:Shard, Replicate

返回

每个 rank 上的一个 DTensor 对象

返回类型

DTensor

torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

返回一个 DTensor,其中填充了从区间 [0, 1) 上的均匀分布中抽取的随机数。张量的形状由可变参数 size 定义。

参数

size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

关键字参数
  • dtype (torch.dtype, optional) – 返回的 DTensor 的所需数据类型。默认值:如果为 None,则使用全局默认值(请参阅 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 应该记录返回的 DTensor 上的操作。默认值:False

  • device_meshDeviceMesh 类型,包含 rank 的 mesh 信息。

  • placementsPlacement 类型的序列:Shard, Replicate

返回

每个 rank 上的一个 DTensor 对象

返回类型

DTensor

torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

返回一个 DTensor,其中填充了从均值为 0 和方差为 1 的正态分布中抽取的随机数。张量的形状由可变参数 size 定义。

参数

size (int...) – 定义输出 DTensor 形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

关键字参数
  • dtype (torch.dtype, optional) – 返回的 DTensor 的所需数据类型。默认值:如果为 None,则使用全局默认值(请参阅 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 返回的 DTensor 的所需布局。默认值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 应该记录返回的 DTensor 上的操作。默认值:False

  • device_meshDeviceMesh 类型,包含 rank 的 mesh 信息。

  • placementsPlacement 类型的序列:Shard, Replicate

返回

每个 rank 上的一个 DTensor 对象

返回类型

DTensor

调试

日志记录

启动程序时,您可以使用来自 torch._loggingTORCH_LOGS 环境变量来开启额外的日志记录。

  • TORCH_LOGS=+dtensor 将显示 logging.DEBUG 消息以及所有更高级别的消息。

  • TORCH_LOGS=dtensor 将显示 logging.INFO 消息及更高级别的消息。

  • TORCH_LOGS=-dtensor 将显示 logging.WARNING 消息及更高级别的消息。

调试工具

为了调试应用了 DTensor 的程序,并更详细地了解幕后发生的集合通信,DTensor 提供了 CommDebugMode

class torch.distributed.tensor.debug.CommDebugMode

CommDebugMode 是一个上下文管理器,用于计算其上下文中的功能性集合通信的数量。它使用 TorchDispatchMode 来实现这一点。

用法示例

mod = ...
comm_mode = CommDebugMode()
with comm_mode:
    mod.sum().backward()
print(comm_mode.get_comm_counts())
generate_comm_debug_tracing_table(noise_level=3)[source][source]

生成详细的表格,显示模块级别的操作和集合通信跟踪信息。信息量取决于 noise_level

  1. 打印模块级别的集合通信计数

  2. 打印未包含在琐碎操作中的 dTensor 操作、模块信息

  3. 打印未包含在琐碎操作中的操作

  4. 打印所有操作

generate_json_dump(file_name='comm_mode_log.json', noise_level=3)[source][source]

创建用于构建浏览器可视化的 json 文件。0. 打印模块级别的集合通信计数 1. 打印未包含在琐碎操作中的 dTensor 操作 2. 打印未包含在琐碎操作中的操作 3. 打印所有操作

get_comm_counts()[source][source]

以字典形式返回通信计数。

返回

以字典形式表示的通信计数。

返回类型

Dict[Any, int]

get_parameter_info()[source][source]
返回类型

Dict[str, Dict[str, Any]]

get_sharding_info()[source][source]
返回类型

Dict[str, Dict[str, Any]]

get_total_counts()[source][source]
返回类型

int

log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)[source][source]

作为控制台 CommDebugMode 输出的替代方案,写入用户指定的文件

为了可视化维度小于 3 的 DTensor 的分片,DTensor 提供了 visualize_sharding()

torch.distributed.tensor.debug.visualize_sharding(dtensor, header='')[source]

在终端中可视化 1D 或 2D DTensor 的分片。

注意

这需要 tabulate 包。对于空张量,将不会打印任何分片信息

实验性功能

DTensor 还提供了一组实验性功能。这些功能要么处于原型设计阶段,要么基本功能已完成,但正在寻求用户反馈。如果您对这些功能有任何反馈,请向 PyTorch 提交 issue。

torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)[source]

context_parallel 是一个实验性 API,用于启用上下文并行 (CP)。此 API 执行两个操作:1) 使用 CP 启用的 SDPA (torch.nn.functional.scaled_dot_product_attention) 补丁 SDPA,2) 沿着序列维度分片 buffers,并且每个 rank 将根据 mesh 保留相应的分片。

参数
  • mesh (DeviceMesh) – 用于上下文并行的设备网格。

  • buffers (Optional[List[torch.Tensor]]) – 使用情况取决于序列维度的缓冲区。示例包括输入批次、标签和位置嵌入缓冲区。这些缓冲区必须沿着序列维度进行分片,以确保准确性。分片将就地发生,缓冲区的形状将在上下文中更改。缓冲区将在上下文结束后恢复。no_restore_buffers 可用于指定哪些缓冲区不需要恢复。请注意,buffers 不应包含任何 nn.Parameter。

  • buffer_seq_dims (Optional[List[int]]) – buffers 的序列维度。

  • no_restore_buffers (Optional[Set[torch.Tensor]]) – 此集合中的缓冲区在上下文退出后将不会恢复。此集合必须是 buffers 的子集。如果缓冲区在上下文退出后不再使用,则可以将这些缓冲区放入此列表以避免额外的恢复时间。

返回类型

Generator[None, None, None]

警告

torch.distributed._tensor.experimental.attention.context_parallel 是 PyTorch 中的一个原型功能。该 API 可能会发生变化。

torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)[source]

local_map() 是一个实验性 API,允许用户将 DTensor 传递给一个旨在应用于 torch.Tensor 的函数。它通过提取 DTensor 的本地组件,调用该函数,并根据 out_placements 将输出包装到 DTensor 来实现。

参数
  • func (Callable) – 要应用于 DTensor 的每个本地分片的函数。

  • out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – DTensorfunc 的扁平化输出中所需的放置位置。如果扁平化的 output 是单个值,则 out_placements 应为 PlacementType 类型。否则,如果扁平化的 output 具有多个值,则 out_placements 应为 PlacementType 值的元组,与扁平化的 output 一对一映射。此外,对于 Tensor 输出,我们使用 PlacementType 作为其放置位置(Tuple[Placement] 值)。对于非 Tensor 输出,PlacementType 应为 None。请注意,唯一的例外是没有传入 DTensor 参数的情况。在这种情况下,即使 out_placements 不是 None,结果函数也应忽略所需的放置位置,因为该函数不是使用 DTensor 运行的。

  • in_placements (Tuple[PlacementType, …], optional) – DTensorfunc 的扁平化输入中所需的放置位置。如果指定了 in_placementslocal_map() 将检查每个 DTensor 参数的放置位置是否与所需的放置位置相同。如果放置位置不同且 redistribute_inputsFalse,则会引发异常。否则,如果 redistribute_inputsTrue,则会首先将参数重新分发到所需的分片放置位置,然后再将其本地张量传递给 func。唯一的例外是当所需的放置位置不是 None 且参数是 torch.Tensor 时。在这种情况下,将跳过放置位置检查,并且参数将直接传递给 func。如果 in_placementsNone,则不会执行任何放置位置检查。默认值:None

  • device_mesh (DeviceMesh, optional) – 所有 DTensor 所在的设备网格。如果未指定,则将从输入 DTensor 的设备网格推断。local_map 要求每个 DTensor 都放置在同一设备网格上。默认值:None。

  • redistribute_inputs (bool, optional) – 布尔值,指示当输入 DTensor 的放置位置与所需的输入放置位置不同时,是否重新分片输入 DTensor。如果此值为 False 并且某些 DTensor 输入的放置位置不同,则会引发异常。默认值:False。

返回

一个 Callable,它将 func 应用于输入 DTensor 的每个本地分片,并返回一个从 func 的返回值构造的 DTensor

Raises
  • AssertionError – 如果输入 DTensor 未放置在同一设备网格上,或者如果它们放置在与传入的 device_mesh 参数不同的设备网格上。

  • AssertionError – 对于任何非 DTensor 输出,我们要求其在 out_placements 中的相应输出放置位置为 None。如果不是这种情况,将引发 AssertionError。

  • ValueError – 如果 redistribute_inputs=False,但输入 DTensor 需要根据 in_placements 进行重新分发。

示例

>>> def mm_allreduce_forward(device_mesh, W, X):
>>>     partial_sum_tensor = torch.mm(W, X)
>>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>>     return reduced_tensor
>>>
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion
>>> local_mm_allreduce_forward = local_map(
>>>     mm_allreduce_forward,
>>>     out_placements=[Replicate()],
>>>     in_placements=[col_wise, row_wise],
>>>     device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, (col_wise))  # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, (row_wise))  # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply local_mm_allreduce_forward to DTensors

注意

此 API 目前是实验性的,可能会发生更改

torch.distributed.tensor.experimental.register_sharding(op)[source]

register_sharding() 是一个实验性 API,允许用户在张量输入和输出为 DTensor 时为运算符注册分片策略。在以下情况下,它可能很有用:(1)op 不存在默认分片策略,例如,当 opDTensor 不支持的自定义运算符时;(2)当用户想要覆盖现有运算符的默认分片策略时。

参数

op (Union[OpOverload, List[OpOverload]]) – 要注册自定义分片函数的运算符或运算符列表。

返回

一个函数装饰器,可用于包装一个函数,该函数定义了在 op 中指定的运算符的分片策略。定义的分片策略将注册到 DTensor,并且如果 DTensor 已经实现了该运算符,则将覆盖默认分片策略。自定义分片函数接受与原始运算符相同的输入(除非参数是 torch.Tensor,否则它将被 DTensor 内部使用的类似张量的对象替换)。该函数应返回一个 2 元组序列,每个元组指定可接受的输出放置位置及其对应的输入放置位置。

示例

>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
>>>     acceptable_shardings = []
>>>
>>>     all_replicate = ([Replicate()], [Replicate(), None, None])
>>>     acceptable_shardings.append(all_replicate)
>>>
>>>     for sharding_dim in range(x.ndim):
>>>         if sharding_dim != softmax_dim:
>>>             all_sharded = (
>>>                 [Shard(sharding_dim)],
>>>                 [Shard(sharding_dim), None, None],
>>>             )
>>>             acceptable_shardings.append(all_sharded)
>>>
>>>     return acceptable_shardings

注意

此 API 目前是实验性的,可能会发生更改

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源