张量并行 - torch.distributed.tensor.parallel¶
张量并行 (TP) 建立在 PyTorch 分布式张量 (DTensor) 之上,并提供不同的并行风格:列式并行、行式并行和序列并行。
警告
张量并行 API 处于实验阶段,可能会发生变化。
使用张量并行对您的 nn.Module
进行并行的入口点是
- torch.distributed.tensor.parallel.parallelize_module(module, device_mesh, parallelize_plan)[source]¶
通过根据用户指定的计划对模块或子模块进行并行化,在 PyTorch 中应用张量并行。
我们根据 parallelize_plan 对模块或子模块进行并行化。parallelize_plan 包含
ParallelStyle
,它指示用户希望如何对模块或子模块进行并行化。用户还可以为每个模块的完全限定名称 (FQN) 指定不同的并行风格。
请注意,
parallelize_module
只接受一个一维DeviceMesh
,如果您有一个二维或 N 维DeviceMesh
,请先将 DeviceMesh 切片为一维子 DeviceMesh,然后传递给此 API(例如device_mesh["tp"]
)。- 参数
module (
nn.Module
) – 要并行化的模块。device_mesh (
DeviceMesh
) – 描述 DTensor 设备网格拓扑的对象。parallelize_plan (Union[
ParallelStyle
, Dict[str,ParallelStyle
]]) – 用于对模块进行并行的计划。它可以是ParallelStyle
对象,其中包含我们如何为张量并行准备输入/输出,或者它可以是模块 FQN 的字典及其对应的ParallelStyle
对象。
- 返回值
一个已并行化的
nn.Module
对象。- 返回值类型
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> >>> # Define the module. >>> m = Model(...) >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) >>>
注意
对于像注意力、MLP 层这样的复杂模块架构,我们建议将不同的 ParallelStyles 组合在一起(例如
ColwiseParallel
和RowwiseParallel
),并将它们作为 parallelize_plan 传递,以实现所需的切片计算。
张量并行支持以下并行风格
- class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]¶
以列式方式对兼容的 nn.Module 进行分区。目前支持 nn.Linear 和 nn.Embedding。用户可以将它与 RowwiseParallel 组合在一起,以实现对更复杂模块(例如 MLP、注意力)的切片。
- 关键字参数
input_layouts (Placement, optional) – nn.Module 的输入张量的 DTensor 布局,用于将输入张量注释为 DTensor。如果没有指定,我们假设输入张量是复制的。
output_layouts (Placement, 可选) – nn.Module 输出的 DTensor 布局,用于确保 nn.Module 的输出具有用户期望的布局。如果未指定,输出张量将在最后一个维度上进行分片。
use_local_output (布尔值, 可选) – 是否使用本地
torch.Tensor
而不是DTensor
作为模块输出,默认:True。
- 返回值
表示 nn.Module 列式分片的
ParallelStyle
对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) >>> ...
注意
默认情况下,如果未指定
output_layouts
,ColwiseParallel
的输出将在最后一个维度上进行分片,如果存在需要特定张量形状的运算符(例如,在配对的RowwiseParallel
之前),请记住,如果输出是分片的,则运算符可能需要调整为分片大小。
- class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)[source]¶
以行方式对兼容的 nn.Module 进行分区。当前支持 nn.Linear 和 nn.Embedding。用户可以将其与 ColwiseParallel 组合以实现更复杂模块的分片(例如,MLP、Attention)。
- 关键字参数
input_layouts (Placement, 可选) – nn.Module 输入张量的 DTensor 布局,用于将输入张量注释为 DTensor。如果未指定,我们假设输入张量在最后一个维度上进行分片。
output_layouts (Placement, 可选) – nn.Module 输出的 DTensor 布局,用于确保 nn.Module 的输出具有用户期望的布局。如果未指定,输出张量将被复制。
use_local_output (布尔值, 可选) – 是否使用本地
torch.Tensor
而不是DTensor
作为模块输出,默认:True。
- 返回值
表示 nn.Module 行式分片的
ParallelStyle
对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>> ...
- class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)[source]¶
SequenceParallel 复制兼容的
nn.Module
参数,并使用在序列维度上进行分片的输入运行分片计算。目前支持nn.LayerNorm
、nn.Dropout
以及 RMSNorm python 实现此样式实现了论文 Reducing Activation Recomputation in Large Transformer Models 中描述的操作。
如果传递到此
nn.Module
的输入是torch.Tensor
,则假设输入已在序列维度上进行分片,并将输入转换为在序列维度上进行分片的DTensor
。如果传递到此nn.Module
的输入已经是DTensor
但未在序列维度上进行分片,它将重新分配输入以便在序列维度上进行分片。nn.Module
的输出将在序列维度上进行分片。- 关键字参数
sequence_dim (整数, 可选) –
nn.Module
输入张量的序列维度,用于将输入张量注释为在序列维度上进行分片的 DTensor,默认:1。use_local_output (布尔值, 可选) – 是否使用本地
torch.Tensor
而不是DTensor
作为模块输出,默认:False。
- 返回值
表示
nn.Module
序列并行的ParallelStyle
对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. >>> >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>> ...
注意
SequenceParallel 样式假设如果 nn.Module 中存在权重(例如
nn.LayerNorm
或RMSNorm
,它们默认情况下具有 ones 初始化),则使用 ones 初始化。如果您对这些模块上的权重有自定义初始化,您需要在并行化之前/之后广播权重以确保它们被复制。
为了简单地使用 DTensor 布局配置 nn.Module 的输入和输出,并执行必要的布局重新分配,而不会将模块参数分配到 DTensors,以下 ParallelStyle
可以用于调用 parallelize_module
时使用的 parallelize_plan
- class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)[source]¶
配置 nn.Module 的输入,以便根据
input_layouts
在运行时将 nn.Module 的输入张量转换为 DTensors,并根据desired_input_layouts
执行布局重新分配。- 关键字参数
input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 输入张量的 DTensor 布局,用于将输入张量转换为 DTensors。如果某些输入不是 torch.Tensor 或不需要转换为 DTensors,则需要指定
None
作为占位符。默认:None。desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) – nn.Module 输入张量的期望 DTensor 布局,用于确保 nn.Module 的输入具有期望的 DTensor 布局。此参数需要与
input_layouts
具有相同的长度。默认:None。input_kwarg_layouts (Dict[字符串, Placement]) – nn.Module 输入 kwargs 的 DTensor 布局,用于将输入 kwarg 张量转换为 DTensors。默认:None
desired_input_kwarg_layouts – (Dict[str, Placement]): nn.Module 输入 kwargs 的期望 DTensor 布局,用于确保 nn.Module 的输入具有期望的 DTensor 布局。默认:None。
use_local_output (布尔值, 可选) – 是否使用本地
torch.Tensor
而不是DTensor
作为模块输入,默认:False。
- 返回值
准备 nn.Module 输入的分片布局的
ParallelStyle
对象。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor >>> # and then redistributed to Replicated DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan={ >>> "attn": PrepareModuleInput( >>> input_layouts=(Shard(0), None, None, ...), >>> desired_input_layouts=(Replicate(), None, None, ...) >>> ), >>> } >>> )
- class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)[source]¶
配置 nn.Module 的输出,以便根据
output_layouts
在运行时将 nn.Module 的输出张量转换为 DTensors,并根据desired_output_layouts
执行布局重新分配。- 关键字参数
output_layouts (Union[Placement, Tuple[Placement]]) – nn.Module 输出张量的 DTensor 布局,用于将输出张量转换为 DTensors(如果它们是
torch.Tensor
)。如果某些输出不是 torch.Tensor 或无需转换为 DTensors,则需要指定None
作为占位符。desired_output_layouts (Union[Placement, Tuple[Placement]]) – nn.Module 输出张量的预期 DTensor 布局,用于确保 nn.Module 的输出具有预期 DTensor 布局。
use_local_output (bool, optional) – 是否使用本地
torch.Tensor
而不是DTensor
作为模块输出,默认值为 True。
- 返回值
一个 ParallelStyle 对象,用于准备 nn.Module 输出的切分布局。
- 示例:
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule >>> tp_mesh = init_device_mesh("cuda", (8,)) >>> >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor >>> # and then redistributed to Sharded DTensor. >>> parallelize_module( >>> block, # this can be a submodule or module >>> tp_mesh, >>> parallelize_plan = PrepareModuleOutput( >>> output_layouts=Replicate(), >>> desired_output_layouts=Shard(0) >>> ) >>> )
注意
当使用 Shard(dim)
作为上述 ParallelStyle
的输入/输出布局时,我们假设输入/输出激活张量在张量维度 dim
上在 TP 操作的 DeviceMesh
上均匀切分。例如,由于 RowwiseParallel
接受在最后一个维度上切分的输入,因此它假设输入张量已经沿着最后一个维度均匀切分。对于不均匀切分的激活张量,可以将 DTensor 直接传递到分区模块,并使用 use_local_output=False
在每个 ParallelStyle
之后返回 DTensor,其中 DTensor 可以跟踪不均匀的切分信息。
对于 Transformer 等模型,我们建议用户在 parallelize_plan 中一起使用 ColwiseParallel
和 RowwiseParallel
来为整个模型(即 Attention 和 MLP)实现所需的切分。
并行化的交叉熵损失计算(损失并行),可以通过以下上下文管理器来支持
- torch.distributed.tensor.parallel.loss_parallel()[source]¶
一个上下文管理器,它可以启用损失并行,当输入在类别维度上切分时,可以执行有效的并行化损失计算。目前仅支持交叉熵损失。
在这个上下文管理器中,可以像往常一样使用
cross_entropy()
或CrossEntropyLoss
,但对输入参数有以下假设。如果存在,相应的backward()
调用也需要在这个上下文管理器中进行。- 参数
input (
DTensor
) – 输入 logits。假设在类别维度上切分。target (Union[
torch.Tensor
,DTensor
]) – 必须是真实类别索引(当前不支持类别概率)。假设在DeviceMesh
上复制。weight (Union[
torch.Tensor
,DTensor
], optional) – 如果给出,则假设在DeviceMesh
上复制。label_smoothing – 当前不支持。
- 返回值
一个复制的
DTensor
。
示例
这里手动创建了一个切分的 DTensor 来展示用法。在实际应用中,它通常是 TP 模块的输出。
>>> from torch.distributed.tensor.parallel import loss_parallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> device_mesh = init_device_mesh("cuda", (8,)) >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) >>> target = torch.randint(16, (4,), device="cuda") >>> with loss_parallel(): >>> loss = F.cross_entropy(dist_input, target, reduction="mean") >>> loss.backward() >>> ...
警告
loss_parallel API 处于实验阶段,可能会发生变化。