• 文档 >
  • torch.distributed.fsdp.fully_shard
快捷方式

torch.distributed.fsdp.fully_shard

PyTorch FSDP2 (fully_shard)

PyTorch FSDP2 提供全分片数据并行(FSDP)实现,以提升 performant eager-mode 的性能,同时通过按参数分片提高易用性。

  • 如果您是 FSDP 新手,我们建议您从 FSDP2 开始,因为它具有更高的易用性。

  • 如果您当前使用 FSDP1,请评估以下差异,看是否应切换到 FSDP2

与 PyTorch FSDP1 (FullyShardedDataParallel) 相比

  • FSDP2 使用基于 DTensor 的 dim-0 按参数分片,与 FSDP1 的平面参数分片相比,分片表示更简单,同时保持相似的吞吐量性能。更具体地说,FSDP2 通过 torch.chunk(dim=0) 在 dim-0 上将每个参数分块到数据并行工作节点上,而 FSDP1 则将一组张量扁平化、连接并分块在一起,使得理解每个工作节点上存在哪些数据以及重新分片到不同的并行模式变得复杂。按参数分片提供了更直观的用户体验,放宽了对冻结参数的限制,并允许使用无需通信的(分片)状态字典,这在 FSDP1 中则需要 all-gather 操作。

  • FSDP2 实现了不同的内存管理方法来处理多流使用,避免了 torch.Tensor.record_stream。这确保了确定性和预期的内存使用,并且不像 FSDP1 的 limit_all_gathers=True 那样需要阻塞 CPU。

  • FSDP2 提供了 API,允许手动控制预取和集合操作调度,为高级用户提供更多定制选项。有关详细信息,请参阅下面的 FSDPModule 方法。

  • FSDP2 简化了一些 API 表面:例如,FSDP2 不直接支持完整状态字典。用户可以使用 DTensor API(如 DTensor.full_tensor())或使用更高级别的 API(如 PyTorch Distributed Checkpoint 的分布式状态字典 API)自行将包含 DTensor 的分片状态字典重新分片为完整状态字典。此外,还移除了一些其他参数;有关详细信息,请参阅此处

如果您是首次接触 FSDP 或以上任何一点符合您的用例,我们建议您考虑使用 FSDP2。

有关系统设计和实现的详细信息,请参阅此 RFC

注意

torch.distributed.fsdp.fully_shard 目前处于原型阶段并正在开发中。核心 API 可能不会改变,但如有必要,我们可能会进行一些 API 更改。

前端 API 是 fully_shard,可以在 module 上调用

torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)[源代码]

将全分片数据并行 (FSDP) 应用于 module,其中 FSDP 在数据并行工作节点之间分片模块参数、梯度和优化器状态,以牺牲通信开销来节省内存。

初始化时,FSDP 会根据 mesh 在数据并行工作节点之间分片模块参数。在前向计算之前,FSDP 会在数据并行工作节点之间 all-gather 分片参数以获取未分片参数用于前向计算。如果 reshard_after_forwardTrue,则 FSDP 会在前向计算后释放未分片参数,并在后向计算梯度之前重新 all-gather 它们。梯度计算完成后,FSDP 会释放未分片参数,并在数据并行工作节点之间 reduce-scatter 未分片梯度。

此实现将分片参数表示为在 dim-0 上分片的 DTensor,而未分片参数将与 module 上的原始参数类似(例如,如果原始参数是 torch.Tensor,则仍是 torch.Tensor)。模块 forward pre-hookmodule 上 all-gather 参数,模块 forward hookmodule 上释放它们(如果需要)。类似的 backward hook 会 all-gather 参数,然后释放参数并 reduce-scatter 梯度。

由于将多个张量组合在一起进行一次集合操作对于通信效率至关重要,此实现将这种分组视为首要功能。在 module 上调用 fully_shard() 会构建一个组,该组包含 module.parameters() 中的参数,但那些已在子模块的早期调用中分配给其他组的参数除外。这意味着 fully_shard() 应该在模型上自底向上调用。每个组的参数在一次集合操作中 all-gather,其梯度在一次集合操作中 reduce-scatter。将模型划分为多个组(“逐层”)可以实现峰值内存节省和通信/计算重叠。用户通常不应该只在最顶层的根模块上调用 fully_shard()

参数
  • module (Union[nn.Module, List[nn.Module]) – 要使用 FSDP 分片并分组进行通信的模块或模块列表。

  • mesh (Optional[DeviceMesh]) – 此数据并行网格定义了分片和设备。如果为 1D,则参数在 1D 网格(FSDP)上使用 (Shard(0),) 放置进行全分片。如果为 2D,则参数在第 1 维上分片,并在第 0 维上复制(HSDP),使用 (Replicate(), Shard(0)) 放置。网格的设备类型指定了用于通信的设备类型;如果是 CUDA 或类似 CUDA 的设备类型,则使用当前设备。

  • reshard_after_forward (Union[bool, int]) –

    控制前向计算后的参数行为,可权衡内存和通信开销:

    • 如果为 True,则在前向计算后重新分片参数,并在后向计算中重新 all-gather。

    • 如果为 False,则在前向计算后在内存中保留未分片参数,并避免后向计算中的 all-gather。

    • 如果为 int,则表示前向计算后重新分片到的世界大小。它应该是网格分片维度大小的一个非平凡因子(即排除 1 和维度大小本身)。一个选择可以是节点内大小(例如 torch.cuda.device_count())。这允许后向计算中的 all-gather 在较小的世界大小上进行,但代价是内存使用高于设置为 True 的情况。

    • 根 FSDP 状态的值被特别设置为 False 作为启发式处理,因为其参数通常会立即进行 all-gather 以用于后向计算。

    • 前向计算后,注册到模块的参数取决于此设置:如果为 True,则注册的参数是分片参数;如果为 False,则为未分片参数;否则为重新分片到较小网格的参数。要在前向和后向计算之间修改参数,注册的参数必须是分片参数。对于 Falseint,可以通过手动调用 reshard() 进行重新分片。

  • shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 此可调用对象可用于覆盖参数的分片放置,以便在 dim-0 以外的维度上分片参数。如果此可调用对象返回 Shard 放置(而非 None),则 FSDP 将根据该放置进行分片(例如 Shard(1))。如果在非零维度上分片,我们当前要求均匀分片,即该维度上的张量维度大小必须能被 FSDP 分片网格大小整除。

  • mp_policy (MixedPrecisionPolicy) – 控制混合精度策略,为此模块提供参数/归约混合精度。有关详细信息,请参阅MixedPrecisionPolicy

  • offload_policy (OffloadPolicy) – 控制卸载策略,为此模块提供参数/梯度/优化器状态卸载。有关详细信息,请参阅OffloadPolicy 及其子类。

  • ignored_params (Optional[set[nn.Parameter]]) – 可选(Set[nn.Parameter]):不希望使用 FSDP 进行分片的参数集合。

返回值

应用了 FSDP 的模块(就地修改)。

返回类型

FSDPModule

调用 fully_shard(module) 会动态构建一个新类,该类是 type(module) 和 FSDP 类 FSDPModule 的子类。例如,如果我们在模块 linear: nn.Linear 上调用 fully_shard(linear),则 FSDP 会构建一个新类 FSDPLinear 并将 linear 的类型更改为此新类。否则,fully_shard 不会改变模块结构和参数的完全限定名称。FSDPModule 类允许在模块上提供一些 FSDP 特有的方法。

class torch.distributed.fsdp.FSDPModule(*args, **kwargs)
reshard()[源代码][源代码]

重新分片模块参数,如果分配了未分片参数,则释放它们并将分片参数注册到模块。此方法不是递归的。

set_all_reduce_hook(hook, *, stream=None)[源代码][源代码]
参数
  • hook (Callable[[torch.Tensor], None]) – 用户定义的 all-reduce 钩子,期望的签名是 hook(reduce_output: torch.Tensor) -> None,其中 reduce_output 是仅使用 FSDP 时的 reduce-scatter 输出,或使用原生 HSDP 时的 all-reduce 输出。

  • stream (Optional[torch.cuda.Stream]) – 运行 all-reduce 钩子的流。仅在不使用原生 HSDP 时应设置此参数。如果使用原生 HSDP,钩子将在原生 HSDP all-reduce 使用的内部定义的 all-reduce 流中运行。

set_is_last_backward(is_last_backward)[源代码][源代码]

设置下一个后向计算是否是最后一个。在最后一个后向计算中,FSDP 会等待未完成的梯度归约,并清除用于后向预取的内部数据结构。这对于微批处理非常有用。

set_modules_to_backward_prefetch(modules)[源代码][源代码]

设置此 FSDP 模块应在后向计算中显式预取 all-gather 的 FSDP 模块。这会覆盖根据逆前向计算后顺序预取下一个 FSDP 模块的默认后向预取实现。

传递包含前一个 FSDP 模块的单元素列表,会获得与默认重叠行为相同的 all-gather 重叠行为。传递至少包含两个元素的列表,可以实现更积极的重叠,并且会使用更多预留内存。

参数

modules (List[FSDPModule]) – 要预取的 FSDP 模块列表。

set_modules_to_forward_prefetch(modules)[源代码][源代码]

设置此 FSDP 模块应在前向计算中显式预取 all-gather 的 FSDP 模块。预取在此模块的 all-gather copy-out 后运行。

传递包含下一个 FSDP 模块的单元素列表,会获得与默认重叠行为相同的 all-gather 重叠行为,但预取的 all-gather 会从 CPU 更早发出。传递至少包含两个元素的列表,可以实现更积极的重叠,并且会使用更多预留内存。

参数

modules (List[FSDPModule]) – 要预取的 FSDP 模块列表。

set_post_optim_event(event)[源代码][源代码]

为根 FSDP 模块设置一个优化器步骤后事件,以便等待 all-gather 流。

默认情况下,根 FSDP 模块会在当前流上等待 all-gather 流,以确保优化器步骤在 all-gather 之前完成。然而,如果在优化器步骤后有不相关的计算,这可能会引入错误依赖。此 API 允许用户提供自己的事件来等待。根模块等待事件后,事件将被丢弃,因此每次迭代都应使用新事件调用此 API。

参数

event (torch.Event) – 在优化器步骤之后记录的事件,用于等待全局收集流。

set_reduce_scatter_divide_factor(factor)[source][source]

设置 reduce-scatter 的自定义除数因子。这会成为使用 NCCL 的 PreMulSum 的自定义规约操作,允许在规约前乘以该因子。

参数

factor (浮点型) – 自定义除数因子。

set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source][source]

设置模块是否应该对梯度进行 all-reduce(全局规约)。这可用于实现仅使用 reduce-scatter 而不使用 all-reduce 的 HSDP 梯度累积。

set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source][source]

设置模块是否应该同步梯度。这可用于实现*不进行通信*的梯度累积。对于 HSDP,这同时控制 reduce-scatter 和 all-reduce。这相当于 FSDP1 中的 no_sync

参数
  • requires_gradient_sync (布尔类型) – 是否对模块的参数进行梯度规约。

  • recurse (布尔类型) – 是否为所有 FSDP 子模块设置,还是仅为传入的模块设置。

set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source][source]

设置模块是否应该在反向传播后重新分片参数。这可在梯度累积期间使用,以权衡更高的内存消耗来减少通信,因为未分片的参数在下一次前向传播前无需重新全局收集。

参数
  • reshard_after_backward (布尔类型) – 是否在反向传播后重新分片参数。

  • recurse (布尔类型) – 是否为所有 FSDP 子模块设置,还是仅为传入的模块设置。

set_unshard_in_backward(unshard_in_backward)[source][source]

设置 FSDP 模块的参数在反向传播中是否需要取消分片。这可在专家场景中使用,当用户知道此 FSDP 模块参数组中的所有参数都不需要用于反向计算时(例如,嵌入层)。

unshard(async_op=False)[source][source]

通过分配内存和全局收集参数来取消模块的参数分片。此方法*不是*递归的。取消分片遵循 MixedPrecisionPolicy,因此如果设置了 param_dtype,它将根据 param_dtype 进行全局收集。

参数

async_op (布尔类型) – 如果为 True,则返回一个 UnshardHandle,它具有 wait() 方法来等待取消分片操作。如果为 False,则返回 None 并在函数内部等待 handle。

返回类型

可选[UnshardHandle]

注意

如果 async_op=True,则 FSDP 将在模块的前向传播前(pre-forward)为用户等待待处理的取消分片操作。用户仅需在等待必须发生在前向传播前(pre-forward)时才显式调用 wait()

class torch.distributed.fsdp.UnshardHandle

用于等待 FSDPModule.unshard() 操作的 handle。

wait()[source][source]

等待取消分片操作。这确保当前流可以使用已取消分片的参数,这些参数现已注册到模块。

torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]

module 上注册一个方法,使其被视为 FSDP 的前向方法。

FSDP 在前向传播前(pre-forward)全局收集参数,并可选地在后向传播后(post-forward)释放参数(取决于 reshard_after_forward)。默认情况下,FSDP 只知道对 nn.Module.forward() 执行此操作。此函数修补一个用户指定的方法,使其分别在该方法之前/之后运行前向传播前/后(pre/post-forward)钩子。如果 module 不是 FSDPModule,则此操作无效(no-op)。

参数
  • module (nn.Module) – 用于注册前向方法的模块。

  • method_name (字符串类型) – 前向方法的名称。

class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)

这配置了 FSDP 的混合精度。与 autocast 不同,它在模块级别而不是操作级别应用混合精度,这意味着低精度激活会为反向传播保留,并且高精度到低精度的转换仅发生在模块边界。

FSDP 与模块级混合精度配合良好,因为它无论如何都会在内存中保留高精度的分片参数。换句话说,FSDP 不需要额外的内存来为优化器步骤保留参数的高精度副本。

变量
  • param_dtype (Optional[torch.dtype]) – 这指定了未分片参数的 dtype,因此也指定了前向/反向计算以及参数全局收集的 dtype。如果为 None,则未分片参数使用原始 dtype。优化器步骤使用原始 dtype 中的分片参数。(默认值:None

  • reduce_dtype (Optional[torch.dtype]) – 这指定了梯度规约的 dtype(即 reduce-scatter 或 all-reduce)。如果为 Noneparam_dtype 不为 None,则规约使用计算 dtype。这可用于在全精度下运行梯度规约,同时使用低精度进行计算。如果通过 set_requires_gradient_sync() 也禁用了梯度规约,则 FSDP 将使用 reduce_dtype 累积梯度。(默认值:None

  • output_dtype (Optional[torch.dtype]) – 这指定了浮点前向输出的转换 dtype。这可用于帮助实现不同模块具有不同混合精度策略的场景。(默认值:None

  • cast_forward_inputs (布尔类型) – 这指定 FSDP 是否应将前向的浮点输入张量转换为 param_dtype

class torch.distributed.fsdp.OffloadPolicy

此基类表示不进行卸载的策略,仅用作 offload_policy 参数的默认值。

class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)

此卸载策略将参数、梯度和优化器状态卸载到 CPU。分片参数在全局收集前从主机复制到设备。全局收集的参数根据 reshard_after_forward 释放。分片梯度在反向传播中从设备复制到主机,优化器步骤在 CPU 上使用 CPU 优化器状态运行。

变量

pin_memory (布尔类型) – 是否锁定分片参数和梯度内存。锁定内存可以提高 H2D/D2H 复制效率,并使复制与计算重叠。但是,其他进程无法使用锁定内存。如果 CPU 内存不足,请将此项设置为 False。(默认值:True

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源