torch.distributed.tensor¶
注意
torch.distributed.tensor
目前处于 alpha 阶段,正在开发中。我们致力于为文档中列出的大多数 API 提供向后兼容性,但在必要时可能会进行 API 更改。
PyTorch DTensor (分布式张量)¶
PyTorch DTensor 提供了简单灵活的张量分片原语,可透明地处理分布式逻辑,包括跨设备/主机的分片存储、算子计算和集合通信。DTensor
可用于构建不同的并行解决方案,并在处理多维分片时支持分片 state_dict
表示。
请参阅基于 DTensor
构建的 PyTorch 原生并行解决方案示例:
DTensor
遵循 SPMD(单程序多数据)编程模型,使用户能够像编写具有相同收敛属性的单设备程序一样编写分布式程序。它通过指定 DeviceMesh
和 Placement
提供统一的张量分片布局(DTensor Layout)。
DeviceMesh
使用 n 维数组表示集群的设备拓扑和通信器。Placement
描述了逻辑张量在DeviceMesh
上的分片布局。DTensor 支持三种类型的 Placement:Shard
、Replicate
和Partial
。
DTensor 类 API¶
DTensor
是 torch.Tensor
的子类。这意味着一旦创建 DTensor
,就可以以与 torch.Tensor
非常相似的方式使用它,包括运行不同类型的 PyTorch 算子,就像在单个设备上运行一样,从而实现 PyTorch 算子的正确分布式计算。
除了现有的 torch.Tensor
方法外,它还提供了一组额外的方法来与 torch.Tensor
交互、将 DTensor Layout redistribute
到新的 DTensor、获取所有设备上的完整张量内容等。
- class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)¶
DTensor
(分布式张量)是torch.Tensor
的子类,提供类似单设备的抽象,用于对多设备torch.Tensor
进行编程。它通过DeviceMesh
和以下类型的Placement
来描述分布式张量分片布局(DTensor Layout):Shard
:张量在DeviceMesh
维度的设备上按张量维度dim
进行分片Replicate
:张量在DeviceMesh
维度的设备上进行复制Partial
:张量在DeviceMesh
维度的设备上等待归约
调用 PyTorch 算子时,
DTensor
会覆盖 PyTorch 算子以执行分片计算并在必要时发出通信。除了算子计算之外,DTensor
会正确转换或传播 Placement(DTensor Layout)(基于算子本身的语义)并生成新的DTensor
输出。为了确保调用 PyTorch 算子时
DTensor
分片计算的数值正确性,DTensor
要求算子的每个张量参数都是 DTensor。注意
不建议在此处直接使用 Tensor 子类构造函数创建
DTensor
(因为它无法正确处理自动微分,因此不是公共 API)。请参阅 create_dtensor 部分以了解如何创建DTensor
。- 返回类型
- __create_chunk_list__()[source][source]¶
返回一个 ChunkStorageMetadata 列表,这是一个描述当前 rank 上本地分片/复制品的大小/偏移量的数据类。对于 DTensor,每个 rank 将拥有一个本地分片/复制品,因此返回列表通常只有一个元素。
这个 dunder 方法主要用于分布式检查点目的。
- 返回值
一个 List[
ChunkStorageMetadata
] 对象,表示当前 rank 上的分片大小/偏移量。
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source][source]¶
根据指定的
device_mesh
和placements
,从每个 rank 上的本地torch.Tensor
创建一个DTensor
。- 参数
local_tensor (torch.Tensor) – 每个 rank 上的本地 torch.Tensor。
device_mesh (
DeviceMesh
, optional) – 放置张量的 DeviceMesh,如果未指定,则必须在 DeviceMesh 上下文管理器下调用,默认值:Noneplacements (List[
Placement
], optional) – 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的 placements,必须与device_mesh.ndim
具有相同数量的元素。
- 关键字参数
run_check (bool, optional) – 以额外通信为代价,执行跨 rank 的健全性检查,检查每个本地张量的元信息以确保正确性。如果
placements
中有Replicate
,Device Mesh 维度上第一个 rank 的数据将被广播到其他 rank。默认值:Falseshape (torch.Size, optional) – 一个整数列表,指定基于 local_tensor 构建的 DTensor 的大小。请注意,如果
local_tensor
的形状在不同 rank 上不同,则需要提供此参数。如果未提供,shape
将假设给定分布式张量在 rank 之间均匀分片来计算。默认值:Nonestride (tuple, optional) – 一个整数列表,指定 DTensor 的步长。如果未提供,
stride
将假设给定分布式张量在 rank 之间均匀分片来计算。默认值:None
- 返回值
一个
DTensor
对象- 返回类型
注意
当
run_check=False
时,用户有责任确保传入的本地张量在跨 rank 上是正确的(即对于Shard(dim)
Placement 张量被分片,或者对于Replicate()
Placement 张量被复制)。如果不是,则创建的 DTensor 的行为是未定义的。注意
from_local
是可微分的,创建的 DTensor 对象的 requires_grad 将取决于 local_tensor 是否需要梯度。
- full_tensor(*, grad_placements=None)[source][source]¶
返回此 DTensor 的完整张量。它将执行必要的集合操作,从其 DeviceMesh 中的其他 rank 收集本地张量并将它们拼接在一起。它是以下代码的语法糖:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
- 关键字参数
grad_placements (List[
Placement
], optional) – placements 描述了此函数返回的完整张量的任何梯度布局的未来布局。full_tensor 将 DTensor 转换为完整的 torch.Tensor,并且返回的 torch.tensor 可能在代码后续部分不会像原始复制的 DTensor layout 一样使用。此参数是用户可以提供给自动微分的提示,以防返回张量的梯度布局与原始复制的 DTensor layout 不匹配。如果未指定,我们将假定完整张量的梯度布局被复制。- 返回值
一个
torch.Tensor
对象,表示此 DTensor 的完整张量。- 返回类型
注意
full_tensor
是可微分的。
- redistribute(device_mesh=None, placements=None, *, async_op=False)[source][source]¶
redistribute
执行必要的集合操作,将当前 DTensor 从其当前 placements 重新分布到新的 placements,或者从其当前 DeviceMesh 重新分布到新的 DeviceMesh。例如,我们可以通过为 DeviceMesh 的每个维度指定 Replicate Placement,将一个 Sharded DTensor 转换为 Replicated DTensor。当从一个 Device Mesh 维度上的当前 placements 重新分布到新的 placements 时,我们将执行以下操作,包括通信集合或局部操作:
Shard(dim)
->Replicate()
:all_gather
Shard(src_dim)
->Shard(dst_dim)
:all_to_all
Replicate()
->Shard(dim)
: 局部分块(例如torch.chunk
)Partial()
->Replicate()
:all_reduce
Partial()
->Shard(dim)
:reduce_scatter
redistribute
会正确地找出针对 1-D 或 N-D DeviceMesh 上创建的 DTensor 的必要重新分布步骤。- 参数
device_mesh (
DeviceMesh
, optional) – 放置 DTensor 的 DeviceMesh。如果未指定,将使用当前 DTensor 的 DeviceMesh。默认值:Noneplacements (List[
Placement
], optional) – 描述如何将 DTensor 放置到 DeviceMesh 中的新 placements,必须与device_mesh.ndim
具有相同数量的元素。默认值:在所有 Device Mesh 维度上复制
- 关键字参数
async_op (bool, optional) – 是否异步执行 DTensor 重新分布操作。默认值:False
- 返回值
一个
DTensor
对象- 返回类型
注意
redistribute
是可微分的,这意味着用户无需担心重新分布操作的向后公式。注意
redistribute
目前仅支持在同一 DeviceMesh 上重新分布 DTensor。如果您需要将 DTensor 重新分布到不同的 DeviceMesh,请提交一个 issue。
- to_local(*, grad_placements=None)[source][source]¶
获取此 DTensor 在其当前 rank 上的本地张量。对于分片(sharding),它返回逻辑张量视图的本地分片(local shard),对于复制(replication),它返回其当前 rank 上的复制品(replica)。
- 关键字参数
grad_placements (List[
Placement
], optional) – placements 描述了此函数返回的张量的任何梯度布局的未来布局。to_local 将 DTensor 转换为本地张量,并且返回的本地张量在代码后续部分可能不会像原始 DTensor layout 一样使用。此参数是用户可以提供给自动微分的提示,以防返回张量的梯度布局与原始 DTensor layout 不匹配。如果未指定,我们将假定梯度布局与原始 DTensor 保持相同,并将其用于梯度计算。- 返回值
一个
torch.Tensor
或AsyncCollectiveTensor
对象。它表示在其当前 rank 上的本地张量。当返回AsyncCollectiveTensor
对象时,意味着本地张量尚未准备好(即通信尚未完成)。在这种情况下,用户需要调用wait
来等待本地张量准备好。- 返回类型
注意
to_local
是可微分的,返回的本地张量的requires_grad
将取决于 DTensor 是否需要梯度。
- property device_mesh: DeviceMesh¶
与此 DTensor 对象关联的
DeviceMesh
属性。注意
device_mesh
是只读属性,无法设置。
- property placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]¶
此 DTensor 的 placements 属性,描述了此 DTensor 在其 DeviceMesh 上的布局。
注意
placements
是只读属性,无法设置。
作为分布式通信器的 DeviceMesh¶
DeviceMesh
是从 DTensor 构建的抽象,用于描述集群的设备拓扑并表示多维通信器(基于 ProcessGroup
)。有关如何创建/使用 DeviceMesh 的详细信息,请参阅 DeviceMesh recipe。
DTensor Placement 类型¶
DTensor 支持在每个 DeviceMesh
维度上使用以下类型的 Placement
:
- class torch.distributed.tensor.placement_types.Shard(dim)[source][source]¶
Shard(dim)
Placement 描述了 DTensor 在对应DeviceMesh
维度上按张量维度dim
进行的分片,其中 DeviceMesh 维度上的每个 rank 仅持有全局张量的一个分片/部分。Shard(dim)
Placement 遵循torch.chunk(dim)
的语义,当张量维度不能在 DeviceMesh 维度上均匀整除时,DeviceMesh 维度上的最后几个分片可能是空的。Shard
Placement 可以用于所有 DTensor API(例如distribute_tensor
、from_local
等)。- 参数
dim (int) – 描述 DTensor 在其对应 DeviceMesh 维度上分片的张量维度。
警告
在张量维度大小不能被 DeviceMesh 维度均匀整除的情况下在该张量维度上进行分片,目前是实验性的且可能会发生变化。
- class torch.distributed.tensor.placement_types.Replicate[source][source]¶
Replicate()
布局描述了 DTensor 在对应的DeviceMesh
维度上进行复制,在该 DeviceMesh 维度的每个进程上都持有一个全局张量的副本。所有 DTensor API(例如distribute_tensor
、DTensor.from_local
等)都可以使用Replicate
布局。
- class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source][source]¶
Partial(reduce_op)
布局描述了 DTensor 在指定的DeviceMesh
维度上等待归约,在该 DeviceMesh 维度的每个进程上都持有全局张量的部分值。用户可以使用redistribute
将Partial
DTensor 重新分布到指定的DeviceMesh
维度上的Replicate
或Shard(dim)
布局,这将触发底层的必要通信操作(例如allreduce
、reduce_scatter
)。- 参数
reduce_op (str, optional) – 用于将 partial DTensor 归约生成 Replicated/Sharded DTensor 的归约操作。仅支持逐元素归约操作,包括:“sum”、“avg”、“product”、“max”、“min”,默认值:“sum”。
注意
Partial
布局可以作为 DTensor 算子的结果生成,并且只能由DTensor.from_local
API 使用。
创建 DTensor 的不同方法¶
- 构造
DTensor
有三种方法 distribute_tensor()
从每个进程上的逻辑或“全局”torch.Tensor
创建一个DTensor
。这可用于对叶子torch.Tensor
(例如模型参数/buffers 和输入)进行分片。DTensor.from_local()
从每个进程上的本地torch.Tensor
创建一个DTensor
,这可用于从非叶子torch.Tensor
(例如前向/后向过程中的中间激活张量)创建DTensor
。DTensor 提供了专用的张量工厂函数(例如
empty()
、ones()
、randn()
等),允许通过直接指定DeviceMesh
和Placement
来创建不同的DTensor
。与distribute_tensor()
相比,这可以直接在设备上具体化分片内存,而无需在初始化逻辑张量内存后再执行分片。
从逻辑 torch.Tensor 创建 DTensor¶
torch.distributed
中的 SPMD(单程序多数据)编程模型启动多个进程(例如通过 torchrun
)执行相同的程序,这意味着程序内的模型将首先在不同进程上初始化(例如模型可能初始化在 CPU 上、meta 设备上,或者如果内存足够直接初始化在 GPU 上)。
DTensor
提供了一个 distribute_tensor()
API,可以将模型权重或张量分片到 DTensor
中,它会从每个进程上的“逻辑”张量创建一个 DTensor。这将使创建的 DTensor
遵守单设备语义,这对于数值正确性至关重要。
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)[source]¶
根据指定的
placements
将叶子torch.Tensor
(例如 nn.Parameter/buffers)分布到device_mesh
。device_mesh
和placements
的秩(rank)必须相同。要分布的tensor
是逻辑或“全局”张量,API 将使用 DeviceMesh 维度第一个进程的tensor
作为真实来源,以保留单设备语义。如果你想在 Autograd 计算过程中构造一个 DTensor,请改用DTensor.from_local()
。- 参数
tensor (torch.Tensor) – 要分布的 torch.Tensor。请注意,如果你想在一个维度上进行分片,而该维度的大小不能被该 mesh 维度上的设备数量均匀整除,我们将使用
torch.chunk
语义来分片张量并分散分片。不均匀分片的行为是实验性的,可能会发生变化。device_mesh (
DeviceMesh
, optional) – 用于分布张量的 DeviceMesh,如果未指定,必须在 DeviceMesh 上下文管理器下调用,默认值:Noneplacements (List[
Placement
], optional) – 描述如何在 DeviceMesh 上放置张量的布局列表,其元素数量必须与device_mesh.ndim
相同。如果未指定,默认情况下,我们将从 device_mesh 的每个维度的第一个进程处复制张量到整个 device_mesh。
- 关键字参数
src_data_rank (int, optional) – 逻辑/全局张量的源数据进程。
distribute_tensor()
使用它来向其他进程分散/广播分片/副本。默认情况下,我们使用每个 DeviceMesh 维度上的group_rank=0
作为源数据,以保留单设备语义。如果显式传递None
,则distribute_tensor()
只使用其本地数据,而不是尝试通过分散/广播保留单设备语义。默认值:0- 返回值
一个
DTensor
或XLAShardedTensor
对象。- 返回类型
注意
当使用
xla
设备类型初始化 DeviceMesh 时,distribute_tensor
将返回XLAShardedTensor
。有关更多详细信息,请参见 此 issue。XLA 集成是实验性的,可能会发生变化。
除了 distribute_tensor()
之外,DTensor 还提供了 distribute_module()
API,以便在 nn.Module
级别上更容易地进行分片
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)[source]¶
该函数暴露了三个函数来控制 module 的 parameters/inputs/outputs
1. 通过指定
partition_fn
在运行时执行之前对 module 执行分片(例如允许用户根据指定的partition_fn
将 Module 参数转换为DTensor
参数)。 2. 通过指定input_fn
和output_fn
在运行时执行期间控制 module 的输入或输出(例如将输入转换为DTensor
,将输出转换回torch.Tensor
)。- 参数
module (
nn.Module
) – 用户要进行分区的 module。device_mesh (
DeviceMesh
) – 用于放置 module 的设备 mesh。partition_fn (Callable) – 分区参数的函数(例如将某些参数分片到
device_mesh
上)。如果未指定partition_fn
,默认情况下我们将 module 的所有 module 参数复制到整个 mesh 上。input_fn (Callable) – 指定输入分布,例如可以控制 module 的输入如何分片。
input_fn
将作为 module 的forward_pre_hook
(前向预处理钩子)安装。output_fn (Callable) – 指定输出分布,例如可以控制输出如何分片,或将其转换回 torch.Tensor。
output_fn
将作为 module 的forward_hook
(前向处理钩子)安装。
- 返回值
一个包含所有参数/buffers 都是
DTensor
的 module。- 返回类型
注意
当使用
xla
设备类型初始化 DeviceMesh 时,distribute_module
将返回带有 PyTorch/XLA SPMD 注解参数的 nn.Module。有关更多详细信息,请参见 此 issue。XLA 集成是实验性的,可能会发生变化。
DTensor 工厂函数¶
DTensor 还提供了专用的张量工厂函数,允许通过直接使用类似 torch.Tensor 工厂函数 API(例如 torch.ones、torch.empty 等)创建 DTensor
,方法是额外指定创建的 DTensor
的 DeviceMesh
和 Placement
。与 distribute_tensor()
相比,这可以直接在设备上具体化分片内存,而无需在初始化逻辑张量内存后再执行分片。
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]¶
返回一个填充标量值 0 的
DTensor
。- 参数
size (int...) – 定义输出
DTensor
形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))- 关键字参数
requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor
上的操作。默认值:False
。dtype (
torch.dtype
, optional) – 返回的DTensor
所需的数据类型。默认值:如果为None
,则使用全局默认值(参见torch.set_default_dtype()
)。layout (
torch.layout
, optional) – 返回的DTensor
所需的布局。默认值:torch.strided
。device_mesh –
DeviceMesh
类型,包含进程的 mesh 信息placements –
Placement
类型的序列:Shard
、Replicate
- 返回值
每个进程上的
DTensor
对象- 返回类型
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]¶
返回一个填充标量值 1 的
DTensor
,其形状由可变参数size
定义。- 参数
size (int...) – 定义输出
DTensor
形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数
dtype (
torch.dtype
, optional) – 返回的DTensor
所需的数据类型。默认值:如果为None
,则使用全局默认值(参见torch.set_default_dtype()
)。layout (
torch.layout
, optional) – 返回的 DTensor 所需的布局。默认值:torch.strided
。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor
上的操作。默认值:False
。device_mesh –
DeviceMesh
类型,包含进程的 mesh 信息placements –
Placement
类型的序列:Shard
、Replicate
- 返回值
每个进程上的
DTensor
对象- 返回类型
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]¶
返回一个填充未初始化数据的
DTensor
。DTensor
的形状由可变参数size
定义。- 参数
size (int...) – 定义输出
DTensor
形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))- 关键字参数
dtype (
torch.dtype
, optional) – 返回的DTensor
所需的数据类型。默认值:如果为None
,则使用全局默认值(参见torch.set_default_dtype()
)。layout (torch.layout
, optional):返回的DTensor
所需的布局。默认值:torch.strided
。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor
上的操作。默认值:False
。device_mesh –
DeviceMesh
类型,包含进程的 mesh 信息placements –
Placement
类型的序列:Shard
、Replicate
- 返回值
每个进程上的
DTensor
对象- 返回类型
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]¶
根据
device_mesh
和placements
返回一个填充fill_value
的DTensor
,其形状由参数size
定义。- 参数
- 关键字参数
dtype (
torch.dtype
, optional) – 返回的DTensor
所需的数据类型。默认值:如果为None
,则使用全局默认值(参见torch.set_default_dtype()
)。layout (
torch.layout
, optional) – 返回的 DTensor 所需的布局。默认值:torch.strided
。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor
上的操作。默认值:False
。device_mesh –
DeviceMesh
类型,包含进程的 mesh 信息。placements –
Placement
类型的序列:Shard
、Replicate
- 返回值
每个进程上的
DTensor
对象- 返回类型
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]¶
返回一个填充在区间
[0, 1)
内均匀分布的随机数的DTensor
。张量的形状由可变参数size
定义。- 参数
size (int...) – 定义输出
DTensor
形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数
dtype (
torch.dtype
, optional) – 返回的DTensor
所需的数据类型。默认值:如果为None
,则使用全局默认值(参见torch.set_default_dtype()
)。layout (
torch.layout
, optional) – 返回的 DTensor 所需的布局。默认值:torch.strided
。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor
上的操作。默认值:False
。device_mesh –
DeviceMesh
类型,包含进程的 mesh 信息。placements –
Placement
类型的序列:Shard
、Replicate
- 返回值
每个进程上的
DTensor
对象- 返回类型
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]¶
返回一个
DTensor
,其中填充了来自均值为 0、方差为 1 的正态分布的随机数。张量的形状由可变参数size
定义。- 参数
size (int...) – 定义输出
DTensor
形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数
dtype (
torch.dtype
, optional) – 返回的DTensor
所需的数据类型。默认值:如果为None
,则使用全局默认值(参见torch.set_default_dtype()
)。layout (
torch.layout
, optional) – 返回的 DTensor 所需的布局。默认值:torch.strided
。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor
上的操作。默认值:False
。device_mesh –
DeviceMesh
类型,包含进程的 mesh 信息。placements –
Placement
类型的序列:Shard
、Replicate
- 返回值
每个进程上的
DTensor
对象- 返回类型
调试¶
日志记录¶
启动程序时,可以使用来自 torch._logging 的 TORCH_LOGS 环境变量开启附加日志记录。
TORCH_LOGS=+dtensor 将显示 logging.DEBUG 级别的消息及其以上的所有级别。
TORCH_LOGS=dtensor 将显示 logging.INFO 级别的消息及其以上级别。
TORCH_LOGS=-dtensor 将显示 logging.WARNING 级别的消息及其以上级别。
调试工具¶
为了调试应用了 DTensor 的程序,并了解底层发生了哪些集合通信操作的更多细节,DTensor 提供了一个 CommDebugMode
。
- class torch.distributed.tensor.debug.CommDebugMode¶
CommDebugMode
是一个上下文管理器,用于计算其上下文中功能性集合通信操作的数量。它通过使用TorchDispatchMode
来实现此功能。注意
并非所有集合通信操作都已支持。
示例用法
mod = ... comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)[source][source]¶
生成详细表格,显示模块级别的操作和集合通信跟踪信息。信息量取决于 noise_level。
打印模块级别的集合通信计数
打印未包含在平凡操作中的 DTensor 操作,模块信息
打印未包含在平凡操作中的操作
打印所有操作
为了可视化维度少于 3 的 DTensor 的分片 (sharding),DTensor 提供了 visualize_sharding()
。
实验性功能¶
DTensor
还提供了一系列实验性功能。这些功能要么处于原型阶段,要么基本功能已完成但正在寻求用户反馈。如果您对这些功能有反馈,请向 PyTorch 提交一个 issue。
- torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)[source]¶
context_parallel
是一个实验性 API,用于启用上下文并行 (CP)。此 API 执行两个操作:1) 使用启用 CP 的版本修补 SDPA (torch.nn.functional.scaled_dot_product_attention
),2) 沿序列维度对buffers
进行分片,并且每个 rank 将根据mesh
保留相应的分片。- 参数
mesh (
DeviceMesh
) – 用于上下文并行的设备网格。buffers (Optional[List[torch.Tensor]]) – 使用依赖于序列维度的缓冲区。例如输入批量、标签和位置嵌入缓冲区。这些缓冲区必须沿序列维度进行分片以确保准确性。分片将原地进行,缓冲区的形状将在上下文中改变。缓冲区将在上下文结束后恢复。
no_restore_buffers
可用于指定哪些缓冲区不需要恢复。请注意,buffers
不应包含任何 nn.Parameter。buffer_seq_dims (Optional[List[int]]) –
buffers
的序列维度。no_restore_buffers (Optional[Set[torch.Tensor]]) – 此集合中的缓冲区在上下文退出后不会被恢复。此集合必须是
buffers
的子集。如果在上下文退出后不再使用这些缓冲区,可以将它们放在此列表中以避免额外的恢复时间。
- 返回类型
Generator[None, None, None]
警告
torch.distributed._tensor.experimental.attention.context_parallel 是 PyTorch 中的一个原型功能。该 API 可能随时更改。
- torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)[source]¶
local_map()
是一个实验性 API,允许用户将DTensor
传递给一个旨在应用于torch.Tensor
的函数。其实现方式是提取DTensor
的本地分量,调用函数,然后根据out_placements
将输出包装成DTensor
。- 参数
func (Callable) – 应用于
DTensor
的每个本地分片上的函数。out_placements (Union[PlacementType, Tuple[PlacementType, …]]) –
func
的展平输出中DTensor
的期望布局。如果展平output
是单个值,则out_placements
应为 PlacementType 类型。否则,如果展平output
有多个值,out_placements
应为 PlacementType 值的元组,与展平output
一一对应。此外,对于Tensor
输出,我们使用 PlacementType 作为其布局(一个 Tuple[Placement] 值)。对于非 Tensor 输出,PlacementType 应为 None。请注意,唯一的例外是没有传入DTensor
参数时。在这种情况下,即使 out_placements 不是 None,结果函数也应忽略期望的布局,因为该函数不是在DTensor
环境下运行的。in_placements (Tuple[PlacementType, …], optional) –
func
的展平输入中DTensor
的所需布局。如果指定了in_placements
,local_map()
将检查每个DTensor
参数的布局是否与所需布局相同。如果布局不同且redistribute_inputs
为False
,将引发异常。否则,如果redistribute_inputs
为True
,参数将首先被重新分发到所需的 sharding 布局,然后才将其本地张量传递给func
。唯一的例外是所需布局不为None
且参数是torch.Tensor
。在这种情况下,将跳过布局检查,参数将直接传递给func
。如果in_placements
为None
,将不执行布局检查。默认值:Nonedevice_mesh (
DeviceMesh
, optional) – 所有DTensor
所在的设备网格。如果未指定,将从输入DTensor
的设备网格中推断。local_map 要求所有DTensor
位于同一个设备网格上。默认值:None。redistribute_inputs (bool, optional) – 布尔值,指示当输入
DTensor
的布局与所需输入布局不同时是否重新分片。如果此值为False
并且某个DTensor
输入具有不同的布局,将引发异常。默认值:False。
- 返回值
一个
Callable
,它将func
应用于输入DTensor
的每个本地分片,并返回由func
返回值构建的DTensor
。- 引发
AssertionError – 如果输入
DTensor
不在同一个设备网格上,或者如果它们所在的设备网格与传入的device_mesh
参数不同。AssertionError – 对于任何非 DTensor 输出,我们要求其在
out_placements
中对应的输出布局为 None。如果不是这种情况,将引发 AssertionError。ValueError – 如果
redistribute_inputs=False
但输入DTensor
需要根据in_placements
进行重新分发。
示例
>>> def mm_allreduce_forward(device_mesh, W, X): >>> partial_sum_tensor = torch.mm(W, X) >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) >>> return reduced_tensor >>> >>> W = torch.randn(12, 8, requires_grad=False) >>> X = torch.randn(8, 16, requires_grad=False) >>> Y = torch.mm(W, X) >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], >>> in_placements=[col_wise, row_wise], >>> device_mesh=device_mesh, >>> ) >>> >>> W_dt = distribute_tensor( ... W, device_mesh, (col_wise) ... ) # col-wisely sharded W tensor >>> X_dt = distribute_tensor( ... X, device_mesh, (row_wise) ... ) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward( ... device_mesh, W_dt, X_dt ... ) # apply local_mm_allreduce_forward to DTensors
注意
此 API 目前是实验性的,可能会随时更改
- torch.distributed.tensor.experimental.register_sharding(op)[source]¶
register_sharding()
是一个实验性 API,允许用户在张量输入和输出是 DTensor 时为运算符注册分片策略。在以下情况下可能很有用:(1) 对于op
不存在默认分片策略,例如当op
是DTensor
不支持的自定义运算符时;(2) 当用户想要覆盖现有运算符的默认分片策略时。- 参数
op (Union[OpOverload, List[OpOverload]]) – 要注册自定义分片函数的运算符或运算符列表。
- 返回值
一个函数装饰器,可用于包装一个函数,该函数定义了
op
中指定运算符的分片策略。定义的分片策略将注册到 DTensor,如果 DTensor 已实现了该运算符,则会覆盖默认分片策略。定制的分片函数接受与原始op
相同的输入(除了如果参数是torch.Tensor
,它将被替换为 DTensor 内部使用的张量状对象)。该函数应返回一个 2 元组序列,每个元组指定可接受的输出布局及其对应的输入布局。
示例
>>> @register_sharding(aten._softmax.default) >>> def custom_softmax_sharding(x, dim, half_to_float): >>> softmax_dim = dim if dim >= 0 else dim + x.ndim >>> acceptable_shardings = [] >>> >>> all_replicate = ([Replicate()], [Replicate(), None, None]) >>> acceptable_shardings.append(all_replicate) >>> >>> for sharding_dim in range(x.ndim): >>> if sharding_dim != softmax_dim: >>> all_sharded = ( >>> [Shard(sharding_dim)], >>> [Shard(sharding_dim), None, None], >>> ) >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings
注意
此 API 目前是实验性的,可能会随时更改