• 文档 >
  • torch.distributed.fsdp.fully_shard
快捷方式

torch.distributed.fsdp.fully_shard

PyTorch FSDP2 (fully_shard)

PyTorch FSDP2 提供完全分片数据并行 (FSDP) 实现,目标是在使用每参数分片以提高可用性的同时实现高性能的 eager 模式。

  • 如果您是 FSDP 的新手,我们建议您从 FSDP2 开始,因为它具有更高的可用性。

  • 如果您目前正在使用 FSDP1,请考虑评估以下差异,以确定是否应该切换到 FSDP2

与 PyTorch FSDP1 (FullyShardedDataParallel) 相比

  • FSDP2 使用基于 DTensor 的 dim-0 每参数分片,以获得比 FSDP1 的平面参数分片更简单的分片表示,同时保持相似的吞吐量性能。更具体地说,FSDP2 在数据并行工作进程之间按 dim-0 对每个参数进行分块(使用 torch.chunk(dim=0)),而 FSDP1 则展平、连接并将一组张量分块在一起,这使得推理每个工作进程上存在哪些数据以及重新分片到不同的并行性变得复杂。每参数分片提供了更直观的用户体验,放宽了对冻结参数的约束,并允许无通信(分片)状态字典,否则在 FSDP1 中需要 all-gather。

  • FSDP2 实现了不同的内存管理方法来处理多流使用,避免了 torch.Tensor.record_stream。这确保了确定性和预期的内存使用,并且不需要像 FSDP1 的 limit_all_gathers=True 中那样阻塞 CPU。

  • FSDP2 公开了用于手动控制预取和集体调度的 API,从而为高级用户提供了更多自定义选项。有关详细信息,请参阅下面的 FSDPModule 上的方法。

  • FSDP2 简化了一些 API 表面:例如,FSDP2 不直接支持完整状态字典。相反,用户可以使用 DTensor API(如 DTensor.full_tensor())或使用更高级别的 API(如 PyTorch 分布式检查点 的分布式状态字典 API)将包含 DTensor 的分片状态字典重新分片为完整状态字典。此外,还删除了一些其他参数;有关详细信息,请参阅此处

如果您是首次使用 FSDP,或者上述任何一项对您的用例有吸引力,我们建议您考虑使用 FSDP2。

有关系统设计和实现的详细信息,请参阅 此 RFC

注意

torch.distributed.fsdp.fully_shard 目前处于原型状态,正在开发中。核心 API 可能不会更改,但我们可能会在必要时进行一些 API 更改。

前端 API 是 fully_shard,可以在 module 上调用

torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy())[source]

将完全分片数据并行 (FSDP) 应用于 module,其中 FSDP 在数据并行工作进程之间分片模块参数、梯度和优化器状态,以节省内存,但会增加通信成本。

在初始化时,FSDP 根据 mesh 给定的数据并行工作进程分片模块的参数。在正向传播之前,FSDP all-gather 数据并行工作进程之间的分片参数,以获取用于正向计算的未分片参数。如果 reshard_after_forwardTrue,则 FSDP 在正向传播后释放未分片参数,并在反向传播中重新 all-gather 它们,然后再进行梯度计算。在梯度计算之后,FSDP 释放未分片参数,并将未分片梯度 reduce-scatter 到数据并行工作进程中。

此实现将分片参数表示为在 dim-0 上分片的 DTensor,而未分片参数将类似于 module 上的原始参数(例如,如果最初是 torch.Tensor,则为 torch.Tensor)。module 上的模块 正向传播预钩子 all-gather 参数,module 上的模块 正向传播钩子 释放它们(如果需要)。类似的后向钩子 all-gather 参数,稍后释放参数并 reduce-scatter 梯度。

由于将多个张量分组在一起进行一次集体通信对于通信效率至关重要,因此此实现将此分组作为首要考虑因素。在 module 上调用 fully_shard() 会构造一个组,其中包含 module.parameters() 中的参数,但已从先前对子模块的调用中分配给组的参数除外。这意味着应该在模型的自下而上的方向调用 fully_shard()。每个组的参数都在一个集体通信中 all-gather,其梯度在一个集体通信中 reduce-scatter。将模型划分为多个组(“逐层”)可以节省峰值内存并实现通信/计算重叠。用户通常不应仅在最顶层的根模块上调用 fully_shard()

参数
  • module (Union[nn.Module, List[nn.Module]) – 要使用 FSDP 分片并分组在一起进行通信的模块或多个模块。

  • mesh (Optional[DeviceMesh]) – 此数据并行网格定义了分片和设备。如果为 1D,则参数在 1D 网格 (FSDP) 上完全分片,放置为 (Shard(0),)。如果为 2D,则参数在第 1 维上分片,并在第 0 维上复制 (HSDP),放置为 (Replicate(), Shard(0))。网格的设备类型给出了用于通信的设备类型;如果是 CUDA 或类 CUDA 设备类型,则我们使用当前设备。

  • reshard_after_forward (Union[bool, int]) –

    这控制了正向传播后的参数行为,并且可以在内存和通信之间进行权衡

    • 如果为 True,则在正向传播后重新分片参数,并在反向传播中重新 all-gather。

    • 如果为 False,则在正向传播后将未分片参数保留在内存中,并避免在反向传播中进行 all-gather。

    • 如果为 int,则表示正向传播后重新分片到的世界大小。它应该是 mesh 分片维度大小的非平凡约数(即,排除 1 和维度大小本身)。一个选择可以是节点内大小(例如 torch.cuda.device_count())。与设置为 True 相比,这允许在较小的世界大小上进行反向传播中的 all-gather,但代价是更高的内存使用率。

    • 根 FSDP 状态的值已特别设置为 False 作为启发式方法,因为其参数通常会立即 all-gather 以进行反向传播。

    • 正向传播后,注册到模块的参数取决于此:如果为 True,则注册的参数是分片参数;如果为 False,则为未分片参数;否则为重新分片到较小网格的参数。要在正向传播和反向传播之间修改参数,则注册的参数必须是分片参数。对于 Falseint,可以通过使用 reshard() 手动重新分片来完成此操作。

  • shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]) – 此可调用对象可用于覆盖参数的分片放置,以在维度(而非 dim-0)上分片参数。如果此可调用对象返回 Shard 放置(而非 None),则 FSDP 将根据该放置进行分片(例如 Shard(1))。如果在非零维度上分片,我们目前要求均匀分片,即该维度上的张量维度大小必须可被 FSDP 分片网格大小整除。

  • mp_policy (MixedPrecisionPolicy) – 这控制混合精度策略,该策略为此模块提供参数/缩减混合精度。有关详细信息,请参阅 MixedPrecisionPolicy

  • offload_policy (OffloadPolicy) – 这控制卸载策略,该策略提供参数/梯度/优化器状态卸载。有关详细信息,请参阅 OffloadPolicy 及其子类。

调用 fully_shard(module) 会动态构造一个新类,该类继承自 type(module) 和 FSDP 类 FSDPModule。例如,如果我们对模块 linear: nn.Linear 调用 fully_shard(linear),则 FSDP 会构造一个新类 FSDPLinear 并将 linear 的类型更改为此类。否则,fully_shard 不会更改模块结构和参数完全限定名称。类 FSDPModule 允许在模块上提供一些 FSDP 特定的方法。

class torch.distributed.fsdp.FSDPModule(*args, **kwargs)
reshard()[source][source]

重新分片模块的参数,如果已分配,则释放未分片参数,并将分片参数注册到模块。此方法是递归的。

set_is_last_backward(is_last_backward)[source][source]

设置下一个反向传播是否为最后一个。在最后一个反向传播中,FSDP 等待挂起的梯度缩减,并清除用于反向传播预取的内部数据结构。这对于微批处理很有用。

set_modules_to_backward_prefetch(modules)[source][source]

设置此 FSDP 模块应在反向传播中显式预取 all-gather 的 FSDP 模块。这将覆盖默认的反向传播预取实现,该实现基于反向后向传播顺序预取下一个 FSDP 模块。

传递包含先前 FSDP 模块的单例列表会给出与默认重叠行为相同的 all-gather 重叠行为。要实现更积极的重叠并使用更多保留内存,需要传递至少长度为 2 的列表。

参数

modules (List[FSDPModule]) – 要预取的 FSDP 模块。

set_modules_to_forward_prefetch(modules)[source][source]

设置此 FSDP 模块应在正向传播中显式预取 all-gather 的 FSDP 模块。预取在此模块的 all-gather 复制输出后运行。

传递包含下一个 FSDP 模块的单例列表会给出与默认重叠行为相同的 all-gather 重叠行为,不同之处在于预取的 all-gather 更早地从 CPU 发出。要实现更积极的重叠并使用更多保留内存,需要传递至少长度为 2 的列表。

参数

modules (List[FSDPModule]) – 要预取的 FSDP 模块。

set_post_optim_event(event)[source][source]

为根 FSDP 模块设置一个优化器步骤后事件,以等待 all-gather 流完成。

默认情况下,根 FSDP 模块在当前流上等待 all-gather 流,以确保优化器步骤在 all-gather 之前已完成。但是,如果优化器步骤之后存在不相关的计算,则这可能会引入虚假的依赖关系。此 API 允许用户提供自己的事件来等待。在根模块等待事件后,该事件将被丢弃,因此应在每次迭代时使用新事件调用此 API。

参数

event (torch.Event) – 在优化器步骤后记录的事件,用于等待 all-gather 流完成。

set_reduce_scatter_divide_factor(factor)[source][source]

为 reduce-scatter 设置自定义除法因子。这成为使用 NCCL 的 PreMulSum 的自定义 reduce 运算,它允许在缩减之前乘以因子。

参数

factor (float) – 自定义除法因子。

set_requires_all_reduce(requires_all_reduce, *, recurse=True)[source][source]

设置模块是否应 all-reduce 梯度。这可以用于实现梯度累积,对于 HSDP 仅使用 reduce-scatter 而不使用 all-reduce。

set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)[source][source]

设置模块是否应同步梯度。这可以用于实现无通信的梯度累积。对于 HSDP,这同时控制 reduce-scatter 和 all-reduce。

参数
  • requires_gradient_sync (bool) – 是否为模块的参数减少梯度同步。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,或者仅为传入的模块设置。

set_reshard_after_backward(reshard_after_backward, *, recurse=True)[source][source]

设置模块是否应在反向传播后重新分片参数。这可以在梯度累积期间使用,以权衡更高的内存占用来减少通信,因为在下一次前向传播之前,未分片的参数不需要重新全收集。

参数
  • reshard_after_backward (bool) – 是否在反向传播后重新分片参数。

  • recurse (bool) – 是否为所有 FSDP 子模块设置,或者仅为传入的模块设置。

set_unshard_in_backward(unshard_in_backward)[source][source]

设置 FSDP 模块的参数是否需要在反向传播中取消分片。当用户知道此 FSDP 模块的参数组中的所有参数都不需要用于反向计算时(例如,嵌入),可以在专家情况下使用此功能。

unshard(async_op=False)[source][source]

通过分配内存和全收集参数来取消模块的参数分片。此方法是递归的。取消分片遵循 MixedPrecisionPolicy,因此如果设置了 param_dtype,它将按照 param_dtype 进行全收集。

参数

async_op (bool) – 如果 True,则返回一个 UnshardHandle,该句柄具有一个 wait() 方法来等待取消分片操作。如果 False,则返回 None 并在该函数内部等待句柄。

返回类型

Optional[UnshardHandle]

注意

如果 async_op=True,则 FSDP 将在模块的预前向传播中等待待处理的取消分片。用户仅需要在等待应该发生在预前向传播之前时显式调用 wait()

class torch.distributed.fsdp.UnshardHandle

用于等待 FSDPModule.unshard() 操作的句柄。

wait()[source][source]

等待取消分片操作。这确保当前流可以使用未分片的参数,这些参数现在已注册到模块。

torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)[source]

module 上注册一个方法,使其被视为 FSDP 的前向方法。

FSDP 在前向传播前全收集参数,并可选择在前向传播后释放参数(取决于 reshard_after_forward)。默认情况下,FSDP 只知道为 nn.Module.forward() 执行此操作。此函数修补用户指定的方法,以分别在方法之前/之后运行前/后向传播钩子。如果 module 不是 FSDPModule,则这是一个空操作。

参数
  • module (nn.Module) – 要在其上注册前向方法的模块。

  • method_name (str) – 前向方法的名称。

class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True)

这配置了 FSDP 的混合精度。与自动类型转换不同,这在模块级别而不是操作级别应用混合精度,这意味着低精度激活值被保存用于反向传播,并且高精度到低精度的类型转换仅在模块边界发生。

FSDP 与模块级混合精度配合良好,因为它无论如何都会在内存中保留高精度的分片参数。换句话说,FSDP 不需要任何额外的内存来为优化器步骤保留参数的高精度副本。

变量
  • param_dtype (Optional[torch.dtype]) – 这指定了未分片参数的数据类型,因此也是前向/反向传播计算和参数全收集的数据类型。如果为 None,则未分片参数使用原始数据类型。优化器步骤使用原始数据类型的分片参数。(默认值:None

  • reduce_dtype (Optional[torch.dtype]) – 这指定了梯度归约的数据类型(即,归约-分散或全归约)。如果为 Noneparam_dtype 不为 None,则归约使用计算数据类型。这可以用于在梯度归约中使用全精度,同时在计算中使用低精度。如果还通过 set_requires_gradient_sync() 禁用梯度归约,则 FSDP 将使用 reduce_dtype 累积梯度。(默认值:None

  • output_dtype (Optional[torch.dtype]) – 这指定了用于转换浮点前向输出的数据类型。这可以用于帮助实现不同模块具有不同混合精度策略的情况。(默认值:None

  • cast_forward_inputs (bool) – 这指定了 FSDP 是否应将前向传播的浮点输入张量转换为 param_dtype

class torch.distributed.fsdp.OffloadPolicy

此类是无卸载策略的基类,仅用作 offload_policy 参数的默认值。

class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True)

此卸载策略将参数、梯度和优化器状态卸载到 CPU。分片参数在全收集之前被复制到设备。全收集的参数根据 reshard_after_forward 释放。分片梯度在反向传播中从设备复制到主机,优化器步骤在 CPU 上使用 CPU 优化器状态运行。

变量

pin_memory (bool) – 是否锁定分片参数和梯度内存。锁定内存可以更有效地进行 H2D/D2H 复制,并允许复制与计算重叠。但是,锁定的内存不能被其他进程使用。如果您没有足够的 CPU 内存,请将其设置为 False。(默认值:True

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源