torch.distributed.tensor¶
注意
torch.distributed.tensor 目前处于 alpha 状态,正在开发中,我们承诺对文档中列出的大多数 API 保持向后兼容性,但如果有必要,可能会更改 API。
PyTorch DTensor (分布式张量)¶
PyTorch DTensor 提供简单灵活的张量分片原语,透明地处理分布式逻辑,包括分片存储、运算符计算和跨设备/主机的集体通信。 DTensor 可用于构建不同的并行解决方案,并在使用多维分片时支持分片 state_dict 表示。
请查看 PyTorch 原生并行解决方案的示例,这些解决方案构建在 DTensor 之上
DTensor 遵循 SPMD (单程序多数据) 编程模型,使用户能够编写分布式程序,就好像它是具有相同收敛特性的**单设备程序**一样。它通过指定 DeviceMesh 和 Placement 提供统一的张量分片布局 (DTensor 布局)
DeviceMesh使用 n 维数组表示集群的设备拓扑和通信器。Placement描述了逻辑张量在DeviceMesh上的分片布局。DTensor 支持三种类型的放置:Shard、Replicate和Partial。
DTensor 类 API¶
DTensor 是 torch.Tensor 的子类。这意味着一旦创建了 DTensor,它可以用与 torch.Tensor 非常相似的方式使用,包括运行不同类型的 PyTorch 运算符,就好像在单个设备上运行一样,允许为 PyTorch 运算符进行适当的分布式计算。
除了现有的 torch.Tensor 方法之外,它还提供了一组附加方法来与 torch.Tensor 交互,redistribute DTensor 布局到新的 DTensor,获取所有设备上的完整张量内容等等。
- class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)¶
DTensor(分布式张量) 是torch.Tensor的子类,它提供单设备般的抽象来使用多设备torch.Tensor。它通过DeviceMesh和以下类型的Placement描述分布式张量分片布局 (DTensor 布局)Shard: 在DeviceMesh维度的设备上,在张量维度dim上分片的张量Replicate: 在DeviceMesh维度的设备上复制的张量Partial: 张量正在等待在DeviceMesh维度的设备上进行约简
在调用 PyTorch 算子时,
DTensor会覆盖 PyTorch 算子以执行分片计算并在必要时发出通信。 除了算子计算外,DTensor会根据算子语义本身,正确转换或传播位置(DTensor 布局),并生成新的DTensor输出。为了确保调用 PyTorch 算子时
DTensor分片计算的数值正确性,DTensor要求算子的每个 Tensor 参数都是 DTensor。- 返回值类型
- property device_mesh: DeviceMesh¶
与该 DTensor 对象关联的
DeviceMesh属性。注意
device_mesh是一个只读属性,不能设置。
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source]¶
根据指定的
device_mesh和placements,从每个 rank 上的本地 torch.Tensor 创建一个DTensor。- 参数
local_tensor (torch.Tensor) – 每个 rank 上的本地 torch.Tensor。
device_mesh (
DeviceMesh, optional) – 用于放置张量的 DeviceMesh,如果未指定,则必须在 DeviceMesh 上下文管理器下调用,默认值:Noneplacements (List[
Placement], optional) – 描述如何在 DeviceMesh 上放置本地 torch.Tensor 的位置,必须与device_mesh.ndim的元素数量相同。
- 关键字参数
run_check (bool, optional) – 以额外通信为代价,在所有 rank 之间执行健全性检查,以检查每个本地张量的元信息以确保正确性。 如果在
placements中有Replicate,设备网格维度上第一个 rank 上的数据将广播到其他 rank。 默认值:Falseshape (torch.Size, optional) – 一个整数列表,指定在 local_tensor 之上构建的 DTensor 的大小。 请注意,如果
local_tensor的形状在所有 rank 上不同,则需要提供此参数。 如果未提供,则会假设给定的分布式张量在所有 rank 上均匀分片,并计算shape。 默认值:Nonestride (tuple, optional) – 一个整数列表,指定 DTensor 的步长。 如果未提供,则会假设给定的分布式张量在所有 rank 上均匀分片,并计算
stride。 默认值:None
- 返回
一个
DTensor对象- 返回值类型
注意
当
run_check=False时,用户有责任确保传入的本地张量在所有 rank 上都正确(即,张量对于Shard(dim)位置被分片,或者对于Replicate()位置被复制)。 如果没有,则创建的 DTensor 的行为未定义。注意
from_local是可微分的,创建的 DTensor 对象的 requires_grad 将取决于 local_tensor 是否需要梯度。
- full_tensor(*, grad_placements=None)[source]¶
返回此 DTensor 的完整张量。 它将执行必要的集体操作,从其 DeviceMesh 中的其他 rank 收集本地张量,并将它们串联起来。 这是以下代码的语法糖
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()- 关键字参数
grad_placements (List[
Placement], optional) – 位置描述从此函数返回的完整张量的任何梯度布局的未来布局。 full_tensor 将 DTensor 转换为完整的 torch.Tensor,返回的 torch.tensor 可能不会在以后的代码中用作原始复制的 DTensor 布局。 此参数是用户可以提供给 autograd 的提示,以防返回的张量的梯度布局与原始复制的 DTensor 布局不匹配。 如果未指定,我们将假设完整张量的梯度布局为复制的。- 返回
一个
torch.Tensor对象,表示此 DTensor 的完整张量。- 返回值类型
注意
full_tensor是可微分的。
- property placements: Tuple[Placement, ...]¶
此 DTensor 的位置属性,描述了此 DTensor 在其 DeviceMesh 上的布局。
注意
placements是一个只读属性,不能设置。
- redistribute(device_mesh=None, placements=None, *, async_op=False)[source]¶
redistribute执行必要的集体操作,将当前 DTensor 从其当前位置重新分布到新的位置,或者从其当前 DeviceMesh 重新分布到新的 DeviceMesh。 也就是说,我们可以通过为 DeviceMesh 的每个维度指定一个 Replicate 位置来将一个分片的 DTensor 转换为一个复制的 DTensor。当在同一个设备网格维度上从当前位置重新分布到新的位置时,我们将执行以下操作,包括通信集体操作或本地操作
Shard(dim)->Replicate():all_gatherShard(src_dim)->Shard(dst_dim):all_to_allReplicate()->Shard(dim): 本地分块(即torch.chunk)Partial()->Replicate():all_reducePartial()->Shard(dim):reduce_scatter
redistribute将会正确地确定在 1-D 或 N-D DeviceMesh 上创建的 DTensor 的必要重新分布步骤。- 参数
device_mesh (
DeviceMesh, optional) – 用于放置 DTensor 的 DeviceMesh。 如果未指定,它将使用当前 DTensor 的 DeviceMesh。 默认值:Noneplacements (List[
Placement], optional) – 描述如何将 DTensor 放置到 DeviceMesh 中的新位置,必须与device_mesh.ndim的元素数量相同。 默认值:在所有网格维度上复制
- 关键字参数
async_op (bool, optional) – 是否异步执行 DTensor 重新分布操作。 默认值:False
- 返回
一个
DTensor对象- 返回值类型
注意
redistribute是可微分的,这意味着用户不必担心重新分布操作的反向公式。注意
redistribute目前只支持在同一个 DeviceMesh 上重新分布 DTensor,如果您需要将 DTensor 重新分布到不同的 DeviceMesh,请提交一个问题。
- to_local(*, grad_placements=None)[source]¶
获取此 DTensor 在其当前秩上的本地张量。对于分片,它返回逻辑张量视图的本地分片;对于复制,它返回其当前秩上的副本。
- 关键字参数
grad_placements (List[
Placement], 可选) – 这些布局描述了从此函数返回的张量的任何梯度布局的未来布局。to_local 将 DTensor 转换为本地张量,并且返回的本地张量可能不会在代码中的后续部分中用作原始 DTensor 布局。此参数是用户可以提供给自动梯度的提示,以防返回的张量的梯度布局与原始 DTensor 布局不匹配。如果未指定,我们将假设梯度布局与原始 DTensor 相同,并使用该布局进行梯度计算。- 返回
一个
torch.Tensor或AsyncCollectiveTensor对象。它表示其当前秩上的本地张量。当返回AsyncCollectiveTensor对象时,表示本地张量尚未准备就绪(即通信尚未完成)。在这种情况下,用户需要调用wait来等待本地张量准备就绪。- 返回值类型
注意
to_local是可微分的,返回的本地张量的requires_grad将取决于 DTensor 是否需要梯度。
DeviceMesh 作为分布式通信器¶
DeviceMesh 是从 DTensor 构建的,作为描述集群设备拓扑和表示多维通信器(在 ProcessGroup 之上)的抽象。有关如何创建/使用 DeviceMesh 的详细信息,请参阅 DeviceMesh 食谱。
DTensor 布局类型¶
DTensor 支持在每个 DeviceMesh 维度上以下类型的 Placement
- class torch.distributed.tensor.placement_types.Shard(dim)[source]¶
Shard(dim)布局描述了 DTensor 在张量维度dim上跨对应DeviceMesh维度的分片,其中 DeviceMesh 维度上的每个秩仅保留全局张量的碎片/部分。Shard(dim)布局遵循torch.chunk(dim)语义,其中 DeviceMesh 维度上的最后几个碎片可能为空,当张量维度不能被 DeviceMesh 维度整除时。Shard布局可被所有 DTensor API 使用(即 distribute_tensor、from_local 等)。- 参数
dim (int) – 描述 DTensor 在其对应 DeviceMesh 维度上分片的张量维度。
警告
在张量维度上进行分片,其中张量维度大小不能被 DeviceMesh 维度整除,目前处于实验阶段,可能会发生变化。
- class torch.distributed.tensor.placement_types.Replicate[source]¶
Replicate()布局描述了 DTensor 在对应DeviceMesh维度上的复制,其中 DeviceMesh 维度上的每个秩都保留全局张量的副本。Replicate布局可被所有 DTensor API 使用(即distribute_tensor、DTensor.from_local等)。
- class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source]¶
Partial(reduce_op)布局描述了 DTensor 在指定DeviceMesh维度上等待约简,其中 DeviceMesh 维度上的每个秩都保留全局张量的部分值。用户可以使用redistribute将PartialDTensor 重新分布到指定DeviceMesh维度上的Replicate或Shard(dim)布局,这将在后台触发必要的通信操作(即allreduce、reduce_scatter)。- 参数
reduce_op (str, 可选) – 用于产生 Replicated/Sharded DTensor 的部分 DTensor 的约简操作。仅支持按元素约简操作,包括:“sum”、“avg”、“product”、“max”、“min”,默认值为:“sum”。
注意
Partial布局可以作为 DTensor 运算符的结果生成,并且只能由DTensor.from_localAPI 使用。
创建 DTensor 的不同方法¶
- 有三种方法可以构建一个
DTensor distribute_tensor()从每个秩上的逻辑或“全局”torch.Tensor创建DTensor。这可用于对叶子torch.Tensors(即模型参数/缓冲区和输入)进行分片。DTensor.from_local()从每个秩上的本地torch.Tensor创建DTensor,这可用于从非叶子torch.Tensors(即正向/反向传播期间的中间激活张量)创建DTensor。DTensor 提供专用的张量工厂函数(例如
empty()、ones()、randn()等),通过直接指定DeviceMesh和Placement来创建不同的DTensor。与distribute_tensor()相比,这可以将分片内存直接实例化到设备上,而不是在初始化逻辑张量内存后执行分片。
从逻辑 torch.Tensor 创建 DTensor¶
torch.distributed 中的 SPMD(单程序多数据)编程模型启动多个进程(例如,通过 torchrun)来执行同一个程序,这意味着程序中的模型会首先在不同的进程上初始化(例如,模型可能在 CPU、元设备上初始化,或者如果内存足够,直接在 GPU 上初始化)。
DTensor 提供了一个 distribute_tensor() API,可以将模型权重或张量分片到 DTensor 中,它将从每个进程上的“逻辑”张量创建一个 DTensor。这将使创建的 DTensor 能够符合单设备语义,这对于 **数值正确性** 至关重要。
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None)¶
将叶
torch.Tensor(例如 nn.Parameter/缓冲区)根据指定的placements分布到device_mesh中。device_mesh和placements的秩必须相同。要分布的tensor是逻辑或“全局”张量,API 将使用 DeviceMesh 维度第一个秩的tensor作为真实来源以保留单设备语义。如果要在 Autograd 计算的中间构建 DTensor,请使用DTensor.from_local()代替。- 参数
tensor (torch.Tensor) – 要分布的 torch.Tensor。请注意,如果您想在不能被网格维度中设备数量整除的维度上分片张量,我们将使用
torch.chunk语义来分片张量并散布分片。不均匀分片行为是实验性的,可能会发生变化。device_mesh (
DeviceMesh, optional) – 用于分布张量的 DeviceMesh,如果未指定,则必须在 DeviceMesh 上下文管理器下调用,默认值为:Noneplacements (List[
Placement], optional) – 描述如何在 DeviceMesh 上放置张量的放置位置,必须与device_mesh.ndim具有相同数量的元素。如果未指定,我们将默认在 device_mesh 的每个维度的第一个秩上复制张量。
- 返回
一个
DTensor或XLAShardedTensor对象。- 返回值类型
注意
当使用
xladevice_type 初始化 DeviceMesh 时,distribute_tensor会返回 XLAShardedTensor。有关更多详细信息,请参阅 此问题。XLA 集成是实验性的,可能会发生变化。
除了 distribute_tensor() 之外,DTensor 还提供了一个 distribute_module() API,以允许更容易地在 nn.Module 级别上进行分片
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)¶
此函数公开三个函数来控制模块的参数/输入/输出
1. 通过指定
partition_fn来在运行时执行之前对模块执行分片(例如,允许用户根据指定的 partition_fn 将 Module 参数转换为DTensor参数)。2. 通过指定input_fn和output_fn来控制模块在运行时执行期间的输入或输出。(例如,将输入转换为DTensor,将输出转换回torch.Tensor)- 参数
module (
nn.Module) – 要进行分片的用户模块。device_mesh (
DeviceMesh) – 用于放置模块的设备网格。partition_fn (Callable) – 用于对参数进行分片的函数(例如,在整个
device_mesh上分片某些参数)。如果未指定partition_fn,默认情况下,我们将复制module的所有模块参数到网格中。input_fn (Callable) – 指定输入分布,例如,可以控制模块的输入如何分片。
input_fn将被安装为模块forward_pre_hook(前向挂钩)。output_fn (Callable) – 指定输出分布,例如,可以控制输出如何分片,或者将其转换回 torch.Tensor。
output_fn将被安装为模块forward_hook(后向挂钩)。
- 返回
一个包含所有为
DTensor的参数/缓冲区的模块。- 返回值类型
注意
当使用
xladevice_type 初始化 DeviceMesh 时,distribute_module会返回带有 PyTorch/XLA SPMD 注解参数的 nn.Module。有关更多详细信息,请参阅 此问题。XLA 集成是实验性的,可能会发生变化。
DTensor 工厂函数¶
DTensor 还提供专用的张量工厂函数,以允许通过使用类似于 torch.Tensor 的工厂函数 API(例如 torch.ones、torch.empty 等)直接创建 DTensor,同时还要为创建的 DTensor 指定 DeviceMesh 和 Placement。
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)¶
返回一个填充标量值 0 的
DTensor。- 参数
size (int...) – 定义输出
DTensor形状的一系列整数。可以是可变数量的参数或像列表或元组这样的集合。例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))- 关键字参数
requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor上的操作。默认值:False。dtype (
torch.dtype, 可选) – 返回的DTensor的所需数据类型。默认:如果为None,则使用全局默认值(参见torch.set_default_dtype())。layout (
torch.layout, 可选) – 返回的DTensor的所需布局。默认:torch.strided。device_mesh –
DeviceMesh类型,包含等级的网格信息placements –
Placement类型序列:Shard,Replicate
- 返回
每个等级上的一个
DTensor对象- 返回值类型
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)¶
返回一个用标量值 1 填充的
DTensor,其形状由可变参数size定义。- 参数
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数,也可以是集合,如列表或元组。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数
dtype (
torch.dtype, 可选) – 返回的DTensor的所需数据类型。默认:如果为None,则使用全局默认值(参见torch.set_default_dtype())。layout (
torch.layout, 可选) – 返回的 DTensor 的所需布局。默认:torch.strided。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor上的操作。默认值:False。device_mesh –
DeviceMesh类型,包含等级的网格信息placements –
Placement类型序列:Shard,Replicate
- 返回
每个等级上的一个
DTensor对象- 返回值类型
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)¶
返回一个用未初始化数据填充的
DTensor。DTensor的形状由可变参数size定义。- 参数
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数,也可以是集合,如列表或元组。例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))- 关键字参数
dtype (
torch.dtype, 可选) – 返回的DTensor的所需数据类型。默认:如果为None,则使用全局默认值(参见torch.set_default_dtype())。 layout (torch.layout, 可选): 返回的DTensor的所需布局。默认:torch.strided。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor上的操作。默认值:False。device_mesh –
DeviceMesh类型,包含等级的网格信息placements –
Placement类型序列:Shard,Replicate
- 返回
每个等级上的一个
DTensor对象- 返回值类型
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)¶
返回一个根据
device_mesh和placements用fill_value填充的DTensor,其形状由参数size定义。- 参数
- 关键字参数
dtype (
torch.dtype, 可选) – 返回的DTensor的所需数据类型。默认:如果为None,则使用全局默认值(参见torch.set_default_dtype())。layout (
torch.layout, 可选) – 返回的 DTensor 的所需布局。默认:torch.strided。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor上的操作。默认值:False。device_mesh –
DeviceMesh类型,包含等级的网格信息。placements –
Placement类型序列:Shard,Replicate
- 返回
每个等级上的一个
DTensor对象- 返回值类型
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)¶
返回一个用来自区间
[0, 1)上的均匀分布的随机数填充的DTensor。张量的形状由可变参数size定义。- 参数
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数,也可以是集合,如列表或元组。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数
dtype (
torch.dtype, 可选) – 返回的DTensor的所需数据类型。默认:如果为None,则使用全局默认值(参见torch.set_default_dtype())。layout (
torch.layout, 可选) – 返回的 DTensor 的所需布局。默认:torch.strided。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor上的操作。默认值:False。device_mesh –
DeviceMesh类型,包含等级的网格信息。placements –
Placement类型序列:Shard,Replicate
- 返回
每个等级上的一个
DTensor对象- 返回值类型
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)¶
返回一个用来自均值为 0 且方差为 1 的正态分布的随机数填充的
DTensor。张量的形状由可变参数size定义。- 参数
size (int...) – 定义输出
DTensor形状的整数序列。可以是可变数量的参数,也可以是集合,如列表或元组。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))- 关键字参数
dtype (
torch.dtype, 可选) – 返回的DTensor的所需数据类型。默认:如果为None,则使用全局默认值(参见torch.set_default_dtype())。layout (
torch.layout, 可选) – 返回的 DTensor 的所需布局。默认:torch.strided。requires_grad (bool, optional) – 如果 autograd 应该记录返回的
DTensor上的操作。默认值:False。device_mesh –
DeviceMesh类型,包含等级的网格信息。placements –
Placement类型序列:Shard,Replicate
- 返回
每个等级上的一个
DTensor对象- 返回值类型
调试¶
日志记录¶
启动程序时,可以使用来自 torch._logging 的 TORCH_LOGS 环境变量打开额外的日志记录
TORCH_LOGS=+dtensor 将显示 logging.DEBUG 消息及其以上所有级别。
TORCH_LOGS=dtensor 将显示 logging.INFO 消息及其以上。
TORCH_LOGS=-dtensor 将显示 logging.WARNING 消息及其以上。
调试工具¶
为了调试应用了 DTensor 的程序,并更详细地了解底层发生的集体操作,DTensor 提供了一个 CommDebugMode
- class torch.distributed.tensor.debug.CommDebugMode¶
CommDebugMode是一个上下文管理器,用于统计其上下文内的功能性集体操作数量。它使用TorchDispatchMode来实现。示例用法
mod = ... comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)[source]¶
生成详细的表格,显示模块级别的操作和集体跟踪信息。信息量取决于噪声等级。
打印模块级别的集体计数
打印不包含在简单操作中的 dTensor 操作,模块信息
打印不包含在简单操作中的操作
打印所有操作
为了可视化维度小于 3 的 DTensor 的分片,DTensor 提供了 visualize_sharding()
- torch.distributed.tensor.debug.visualize_sharding(dtensor, header='')¶
在终端中可视化
DTensor的分片,这些DTensor是 1D 或 2D。注意
这需要
tabulate包。对于空张量,不会打印分片信息。
实验性功能¶
DTensor 还提供了一组实验性功能。这些功能要么处于原型设计阶段,要么基本功能已经完成,但正在寻求用户反馈。如果您对这些功能有任何反馈,请向 PyTorch 提交问题。
- torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)¶
local_map()是一个实验性 API,允许用户将DTensor传递给一个函数,该函数被编写为应用于torch.Tensor。它是通过提取DTensor的本地组件,调用函数,并根据out_placements将输出包装到DTensor来完成的。- 参数
func (Callable) – 将应用于
DTensor的每个本地分片的函数。out_placements (Union[PlacementType, Tuple[PlacementType, …]]) –
func平铺输出中DTensor的所需放置位置。如果平铺的output是单个值,则out_placements应该为 PlacementType 类型。否则,如果平铺的output具有多个值,则out_placements应该是一个 PlacementType 值的元组,与平铺的output一一对应。此外,对于Tensor输出,我们使用 PlacementType 作为其放置位置(一个 Tuple[Placement] 值)。对于非张量输出,PlacementType 应该为 None。需要注意的是,唯一例外是当没有DTensor参数传入时。在这种情况下,即使 out_placements 不是 None,结果函数也应该忽略所需的放置位置,因为函数不是与DTensor一起运行的。in_placements (Tuple[PlacementType, …], optional) –
func平铺输入中DTensor的所需放置位置。如果指定了in_placements,则local_map()会检查每个DTensor参数的放置位置是否与其所需放置位置相同。如果放置位置不同,并且redistribute_inputs为False,则会抛出异常。否则,如果redistribute_inputs为True,则该参数将在传递其本地张量到func之前,首先被重新分配到所需的分片放置位置。唯一例外是,当所需放置位置不为None并且参数是torch.Tensor时。在这种情况下,会跳过放置位置检查,并且参数将直接传递给func。如果in_placements为None,则不会执行任何放置位置检查。默认值:Nonedevice_mesh (
DeviceMesh, optional) – 所有DTensor所放置的设备网格。如果没有指定,它将从输入DTensor的设备网格中推断出来。 local_map 要求每个DTensor都放置在同一个设备网格上。默认值:None。redistribute_inputs (bool, optional) – 布尔值,指示当输入
DTensor的放置位置与其所需输入放置位置不同时,是否要重新分片输入DTensor。如果该值为False并且某些DTensor输入具有不同的放置位置,则会抛出异常。默认值:False。
- 返回
一个
Callable,它将func应用于输入DTensor的每个本地分片,并返回一个由func返回值构建的DTensor。- 引发
AssertionError – 如果输入
DTensor未放置在同一设备网格上,或者它们放置在与传递的device_mesh参数不同的设备网格上。AssertionError – 对于任何非 DTensor 输出,我们要求其在
out_placements中的相应输出放置为 None。如果不满足此条件,将引发 AssertionError。ValueError – 如果
redistribute_inputs=False但输入DTensor需要根据in_placements进行重新分配。
示例
>>> def mm_allreduce_forward(device_mesh, W, X): >>> partial_sum_tensor = torch.mm(W, X) >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) >>> return reduced_tensor >>> >>> W = torch.randn(12, 8, requires_grad=False) >>> X = torch.randn(8, 16, requires_grad=False) >>> Y = torch.mm(W, X) >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], >>> in_placements=[col_wise, row_wise], >>> device_mesh=device_mesh, >>> ) >>> >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
注意
此 API 目前处于实验阶段,可能会发生变化。
- torch.distributed.tensor.experimental.register_sharding(op)¶
register_sharding()是一个实验性 API,允许用户在张量输入和输出为 DTensor 时为运算符注册分片策略。当以下情况时,它可能很有用:(1)op不存在默认的分片策略,例如,当op是DTensor不支持的自定义运算符时;(2) 当用户想要覆盖现有运算符的默认分片策略时。- 参数
op (Union[OpOverload, List[OpOverload]]) – 要注册自定义分片函数的运算符或运算符列表。
- 返回
一个函数装饰器,可用于包装一个函数,该函数定义了在
op中指定的运算符的分片策略。定义的分片策略将注册到 DTensor,如果 DTensor 已经实现了该运算符,它将覆盖默认的分片策略。自定义分片函数接受与原始运算符相同的输入(除了如果参数是torch.Tensor,它将被 DTensor 在内部使用的类张量对象替换)。该函数应返回一个由 2 元组组成的序列,每个元组指定可接受的输出放置和其相应的输入放置。
示例
>>> @register_sharding(aten._softmax.default) >>> def custom_softmax_sharding(x, dim, half_to_float): >>> softmax_dim = dim if dim >= 0 else dim + x.ndim >>> acceptable_shardings = [] >>> >>> all_replicate = ([Replicate()], [Replicate(), None, None]) >>> acceptable_shardings.append(all_replicate) >>> >>> for sharding_dim in range(x.ndim): >>> if sharding_dim != softmax_dim: >>> all_sharded = ( >>> [Shard(sharding_dim)], >>> [Shard(sharding_dim), None, None], >>> ) >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings
注意
此 API 目前处于实验阶段,可能会发生变化。