• 文档 >
  • FullyShardedDataParallel
快捷方式

FullyShardedDataParallel

torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source][source]

一个用于在数据并行工作进程中分片模块参数的包装器。

这受到 Xu 等人的工作以及 DeepSpeed 的 ZeRO Stage 3 的启发。FullyShardedDataParallel 通常简称为 FSDP。

要理解 FSDP 内部工作原理,请参考 FSDP 须知

示例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

使用 FSDP 包括包装您的模块,然后初始化您的优化器。这是必需的,因为 FSDP 会更改参数变量。

设置 FSDP 时,需要考虑目标 CUDA 设备。如果设备有 ID (dev_id),您有三个选项:

  • 将模块放置在该设备上

  • 使用 torch.cuda.set_device(dev_id) 设置设备

  • dev_id 传入 device_id 构造函数参数。

这确保了 FSDP 实例的计算设备是目标设备。对于选项 1 和 3,FSDP 初始化总是发生在 GPU 上。对于选项 2,FSDP 初始化发生在模块当前所在的设备上,这可能是一个 CPU。

如果您使用 sync_module_states=True 标志,您需要确保模块在 GPU 上,或者使用 device_id 参数指定 FSDP 在其构造函数中会将模块移至的 CUDA 设备。这是必要的,因为 sync_module_states=True 需要 GPU 通信。

FSDP 还会处理将输入张量移动到前向方法的 GPU 计算设备上,因此您无需手动将它们从 CPU 移动。

对于 use_orig_params=TrueShardingStrategy.SHARD_GRAD_OP 暴露的是未分片的参数,而不是前向传递后的分片参数,这与 ShardingStrategy.FULL_SHARD 不同。如果您想检查梯度,可以使用 summon_full_params 方法并设置 with_grads=True

使用 limit_all_gathers=True 时,您可能会在 FSDP 前向预处理中看到 CPU 线程未发出任何 kernel 的间隙。这是有意为之,表明速率限制器正在生效。通过这种方式同步 CPU 线程可以防止为后续的 all-gather 操作过度分配内存,并且实际上不应延迟 GPU kernel 的执行。

出于与 autograd 相关的原因,FSDP 在前向和后向计算期间将受管模块的参数替换为 torch.Tensor 视图。如果您的模块的前向传递依赖于保存的参数引用,而不是在每次迭代时重新获取引用,那么它将看不到 FSDP 新创建的视图,并且 autograd 将无法正常工作。

最后,当使用 sharding_strategy=ShardingStrategy.HYBRID_SHARD 且分片进程组为节点内,复制进程组为节点间时,设置 NCCL_CROSS_NIC=1 可以帮助在某些集群设置下改善复制进程组上的 all-reduce 时间。

限制

使用 FSDP 时需要注意一些限制:

  • 当前使用 CPU 卸载时,FSDP 不支持在 no_sync() 之外进行梯度累积。这是因为 FSDP 使用新 reduce 后的梯度,而不是与任何现有梯度进行累积,这可能导致结果不正确。

  • FSDP 不支持运行包含在 FSDP 实例中的子模块的前向传递。这是因为子模块的参数会被分片,但子模块本身不是 FSDP 实例,因此其前向传递无法正确地 all-gather 完整的参数。

  • 由于 FSDP 注册后向钩子的方式,它不支持双重后向(double backwards)。

  • FSDP 在冻结参数时有一些限制。对于 use_orig_params=False,每个 FSDP 实例必须管理全部冻结或全部未冻结的参数。对于 use_orig_params=True,FSDP 支持混合冻结和未冻结参数,但建议避免这样做,以防止梯度内存使用量高于预期。

  • 截至 PyTorch 1.12,FSDP 对共享参数的支持有限。如果您的用例需要增强的共享参数支持,请在此 issue 中发表意见。

  • 您应该避免在不使用 summon_full_params 上下文的情况下在前向和后向之间修改参数,因为这些修改可能不会持久化。

参数
  • module (nn.Module) – 这是要使用 FSDP 包装的模块。

  • process_group (可选[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 这是用于分片模型并因此用于 FSDP 的 all-gather 和 reduce-scatter 集体通信的进程组。如果为 None,则 FSDP 使用默认进程组。对于混合分片策略(例如 ShardingStrategy.HYBRID_SHARD),用户可以传入一个进程组的元组,分别表示用于分片和复制的组。如果为 None,则 FSDP 会为用户构建进程组,用于节点内分片和节点间复制。(默认值: None)

  • sharding_strategy (可选[ShardingStrategy]) – 这配置了分片策略,它可能会在内存节省和通信开销之间进行权衡。详见 ShardingStrategy。(默认值: FULL_SHARD)

  • cpu_offload (可选[CPUOffload]) – 这配置了 CPU 卸载。如果设置为 None,则不进行 CPU 卸载。详见 CPUOffload。(默认值: None)

  • auto_wrap_policy (可选[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –

    这指定了一个策略,用于将 FSDP 应用到 module 的子模块上,这对于通信和计算重叠是必需的,从而影响性能。如果为 None,则 FSDP 仅应用于 module 本身,用户应手动(自底向上)将 FSDP 应用于父模块。为了方便,此参数直接接受 ModuleWrapPolicy,允许用户指定要包装的模块类(例如 transformer 块)。否则,这应该是一个可调用对象 (callable),接受三个参数:module: nn.Modulerecurse: boolnonwrapped_numel: int,并返回一个 bool 值,指定在 recurse=False 时是否应将 FSDP 应用于传入的 module,或者在 recurse=True 时是否应继续遍历模块的子树。用户可以向可调用对象添加额外的参数。torch.distributed.fsdp.wrap.py 中的 size_based_auto_wrap_policy 提供了一个示例可调用对象,如果模块子树中的参数数量超过 1 亿 (100M),则将其应用于模块。我们建议在应用 FSDP 后打印模型并根据需要进行调整。

    示例

    >>> def custom_auto_wrap_policy(
    >>>     module: nn.Module,
    >>>     recurse: bool,
    >>>     nonwrapped_numel: int,
    >>>     # Additional custom arguments
    >>>     min_num_params: int = int(1e8),
    >>> ) -> bool:
    >>>     return nonwrapped_numel >= min_num_params
    >>> # Configure a custom `min_num_params`
    >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
    

  • backward_prefetch (可选[BackwardPrefetch]) – 这配置了显式的后向 all-gather 预取。如果为 None,则 FSDP 不进行后向预取,并且在后向传递中没有通信和计算重叠。详见 BackwardPrefetch。(默认值: BACKWARD_PRE)

  • mixed_precision (可选[MixedPrecision]) – 这配置了 FSDP 的原生混合精度。如果设置为 None,则不使用混合精度。否则,可以设置参数、缓冲区和梯度 reduce 的 dtype。详见 MixedPrecision。(默认值: None)

  • ignored_modules (可选[Iterable[torch.nn.Module]]) – 此实例忽略的模块,包括其自身的参数以及子模块的参数和缓冲区。ignored_modules 中直接包含的模块不应是 FullyShardedDataParallel 实例,如果子模块已经是构建好的 FullyShardedDataParallel 实例并嵌套在此实例下,则不会被忽略。在使用 auto_wrap_policy 时,或者当参数的分片不由 FSDP 管理时,此参数可用于避免在模块粒度上分片特定参数。(默认值: None)

  • param_init_fn (可选[Callable[[nn.Module], None]]) –

    一个 Callable[torch.nn.Module] -> None,指定当前在 meta 设备上的模块应如何初始化到实际设备上。从 v1.12 开始,FSDP 通过 is_meta 检测参数或缓冲区在 meta 设备上的模块,并在指定了 param_init_fn 时应用它,否则调用 nn.Module.reset_parameters()。对于这两种情况,实现应**仅**初始化模块自身的参数/缓冲区,而不初始化其子模块的。这是为了避免重复初始化。此外,FSDP 还支持通过 torchdistX (https://github.com/pytorch/torchdistX) 的 deferred_init() API 进行延迟初始化,其中延迟的模块在指定了 param_init_fn 时通过调用它进行初始化,否则调用 torchdistX 的默认 materialize_module()。如果指定了 param_init_fn,则它将应用于所有 meta 设备模块,这意味着它可能需要根据模块类型进行处理。FSDP 在参数展平 (flattening) 和分片之前调用初始化函数。

    示例

    >>> module = MyModule(device="meta")
    >>> def my_init_fn(module: nn.Module):
    >>>     # E.g. initialize depending on the module type
    >>>     ...
    >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
    >>> print(next(fsdp_model.parameters()).device) # current CUDA device
    >>> # With torchdistX
    >>> module = deferred_init.deferred_init(MyModule, device="cuda")
    >>> # Will initialize via deferred_init.materialize_module().
    >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
    

  • device_id (可选[Union[int, torch.device]]) – 一个 inttorch.device,指定 FSDP 初始化(包括模块初始化(如果需要)和参数分片)所在的 CUDA 设备。如果 module 在 CPU 上,应指定此参数以提高初始化速度。如果已设置默认 CUDA 设备(例如通过 torch.cuda.set_device),则用户可以将 torch.cuda.current_device 传递给此参数。(默认值: None)

  • sync_module_states (bool) – 如果为 True,则每个 FSDP 模块将从 rank 0 广播模块参数和缓冲区,以确保它们在各个 rank 之间复制(这会增加构造函数的通信开销)。这有助于通过 load_state_dict 以内存高效的方式加载 state_dict 检查点。详见 FullStateDictConfig 以获取示例。(默认值: False)

  • forward_prefetch (bool) – 如果为 True,则 FSDP 会在当前前向计算之前显式预取下一个前向传递的 all-gather 操作。这仅对 CPU 密集型工作负载有用,在这种情况下,提前发出下一个 all-gather 可能会改善重叠。此参数仅适用于静态图模型,因为预取遵循第一次迭代的执行顺序。(默认值: False)

  • limit_all_gathers (bool) – 如果为 True,则 FSDP 会显式同步 CPU 线程,以确保 GPU 内存仅由两个连续的 FSDP 实例(当前正在运行计算的实例和下一个已预取 all-gather 的实例)使用。如果为 False,则 FSDP 允许 CPU 线程发出 all-gather 操作而无需任何额外的同步。(默认值: True) 我们通常将此功能称为“速率限制器”。只有在特定 CPU 密集型且内存压力较低的工作负载下,此标志才应设置为 False,在这种情况下,CPU 线程可以积极地发出所有 kernel,而无需担心 GPU 内存的使用。

  • use_orig_params (bool) – 将此设置为 True 会使 FSDP 使用 module 的原始参数。FSDP 通过 nn.Module.named_parameters() 向用户暴露这些原始参数,而不是 FSDP 内部的 FlatParameter。这意味着优化器步骤在原始参数上运行,从而支持按原始参数设置超参数。FSDP 保留原始参数变量,并在未分片和分片形式之间操作它们的数据,它们始终是底层未分片或分片 FlatParameter 的视图。使用当前算法时,分片形式总是 1D 的,丢失了原始的张量结构。原始参数对于给定的 rank 可能包含其全部、部分或不包含任何数据。在不包含数据的情况下,其数据将是一个大小为 0 的空张量。用户不应编写依赖于给定原始参数在分片形式中包含哪些数据的程序。使用 torch.compile() 时必须设置为 True。将此设置为 False 会通过 nn.Module.named_parameters() 向用户暴露 FSDP 内部的 FlatParameter。(默认值: False)

  • ignored_states (可选[Iterable[torch.nn.Parameter]], 可选[Iterable[torch.nn.Module]]) – 此 FSDP 实例不管理的被忽略参数或模块,这意味着这些参数不会被分片,其梯度也不会在各个 rank 之间 reduce。此参数与现有的 ignored_modules 参数功能一致,我们可能很快会弃用 ignored_modules。为了向后兼容,我们保留了 ignored_statesignored_modules`,但 FSDP 只允许其中一个被指定为非 None

  • device_mesh (可选[DeviceMesh]) – DeviceMesh 可以用作 process_group 的替代方案。当传入 device_mesh 时,FSDP 将使用底层进程组进行 all-gather 和 reduce-scatter 集体通信。因此,这两个参数需要互斥。对于混合分片策略(例如 ShardingStrategy.HYBRID_SHARD),用户可以传入 2D DeviceMesh 而不是进程组的元组。对于 2D FSDP + TP,用户必须传入 device_mesh 而不是 process_group。有关 DeviceMesh 的更多信息,请访问:https://pytorch.ac.cn/tutorials/recipes/distributed_device_mesh.html

apply(fn)[source][source]

递归地将 fn 应用于每个子模块(由 .children() 返回)以及自身。

典型用途包括初始化模型的参数(另请参见 torch.nn.init)。

torch.nn.Module.apply 相比,此版本在应用 fn 之前还会额外收集完整的参数。不应在另一个 summon_full_params 上下文中使用此方法。

参数

fn (Module -> None) – 应用于每个子模块的函数

返回

self

返回类型

Module

check_is_root()[source][source]

检查此实例是否为根 FSDP 模块。

返回类型

bool

clip_grad_norm_(max_norm, norm_type=2.0)[source][source]

裁剪所有参数的梯度范数。

范数是在将所有参数的梯度视为单个向量后计算的,并且梯度会就地修改。

参数
  • max_norm (floatint) – 梯度的最大范数

  • norm_type (floatint) – 所使用的 p-范数类型。对于无穷范数可以是 'inf'

返回

参数的总范数(视为单个向量)。

返回类型

Tensor

如果每个 FSDP 实例都使用 NO_SHARD,意味着没有梯度在各个 rank 之间分片,则您可以直接使用 torch.nn.utils.clip_grad_norm_()

如果至少有一些 FSDP 实例使用分片策略(即除 NO_SHARD 之外的策略),那么您应该使用此方法而不是 torch.nn.utils.clip_grad_norm_(),因为此方法处理了梯度在不同 rank 之间分片的情况。

返回的总范数将具有 PyTorch 类型提升语义所定义的,所有参数/梯度中“最大”的数据类型(dtype)。例如,如果所有参数/梯度都使用低精度数据类型(dtype),那么返回的范数的数据类型(dtype)将是该低精度数据类型(dtype);但如果存在至少一个参数/梯度使用 FP32,那么返回的范数的数据类型(dtype)将是 FP32。

警告

此方法需要在所有 rank 上调用,因为它使用了集合通信。

static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source][source]

展平一个分片的优化器状态字典。

该 API 与 shard_full_optim_state_dict() 类似。唯一的区别是输入参数 sharded_optim_state_dict 应该由 sharded_optim_state_dict() 返回。因此,每个 rank 上都会有 all-gather 调用来收集 ShardedTensor

参数
返回

请参考 shard_full_optim_state_dict()

返回类型

dict[str, Any]

forward(*args, **kwargs)[source][source]

运行包装模块的前向传播,插入 FSDP 特定的前向传播之前和之后的分片逻辑。

返回类型

Any

static fsdp_modules(module, root_only=False)[source][source]

返回所有嵌套的 FSDP 实例。

这可能包含 module 本身,并且仅当 root_only=True 时包含 FSDP 根模块。

参数
  • module (torch.nn.Module) – 根模块,它可能是 FSDP 模块,也可能不是。

  • root_only (bool) – 是否仅返回 FSDP 根模块。(默认值: False)

返回

嵌套在输入 module 中的 FSDP 模块。

返回类型

List[FullyShardedDataParallel]

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source][source]

返回完整的优化器状态字典。

在 rank 0 上合并完整的优化器状态,并按照 torch.optim.Optimizer.state_dict() 的约定(即包含键 "state""param_groups")将其作为 dict 返回。model 中包含的 FSDP 模块中的展平参数会被映射回其未展平的参数。

此方法需要在所有 rank 上调用,因为它使用了集合通信。但是,如果 rank0_only=True,则状态字典仅在 rank 0 上填充,所有其他 rank 返回一个空的 dict

torch.optim.Optimizer.state_dict() 不同,此方法使用完整的参数名称作为键,而不是参数 ID。

torch.optim.Optimizer.state_dict() 中一样,优化器状态字典中包含的张量未被克隆,因此可能存在别名问题。为了获得最佳实践,建议立即保存返回的优化器状态字典,例如使用 torch.save()

参数
  • model (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例),其参数被传递给了优化器 optim

  • optim (torch.optim.Optimizer) – 针对 model 参数的优化器。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器 optim 的输入,表示参数组的 list 或参数的可迭代对象;如果为 None,则此方法假定输入为 model.parameters()。此参数已废弃,不再需要传入。(默认值: None)

  • rank0_only (bool) – 如果为 True,则仅在 rank 0 上保存填充的 dict;如果为 False,则在所有 rank 上保存。(默认值: True)

  • group (dist.ProcessGroup) – 参数分片所在的模型的进程组,如果使用默认进程组则为 None。(默认值: None)

返回

一个 dict,包含 model 原始未展平参数的优化器状态,并按照 torch.optim.Optimizer.state_dict() 的约定包含键 “state” 和 “param_groups”。如果 rank0_only=True,则非零 rank 返回一个空的 dict

返回类型

Dict[str, Any]

static get_state_dict_type(module)[source][source]

获取以 module 为根的 FSDP 模块的 state_dict 类型和相应的配置。

目标模块不必是 FSDP 模块。

返回

一个 StateDictSettings 对象,包含当前设置的 state_dict 类型以及 state_dict / optim_state_dict 配置。

抛出
  • AssertionError` 如果不同

  • FSDP 子模块的 StateDictSettings 不同。

返回类型

StateDictSettings

property module: Module

返回包装的模块。

named_buffers(*args, **kwargs)[source][source]

返回一个模块缓冲区迭代器,生成缓冲区的名称和缓冲区本身。

summon_full_params() 上下文管理器内部时,拦截缓冲区名称并移除所有出现的 FSDP 特定的展平缓冲区前缀。

返回类型

Iterator[tuple[str, torch.Tensor]]

named_parameters(*args, **kwargs)[source][source]

返回一个模块参数迭代器,生成参数的名称和参数本身。

summon_full_params() 上下文管理器内部时,拦截参数名称并移除所有出现的 FSDP 特定的展平参数前缀。

返回类型

Iterator[tuple[str, torch.nn.parameter.Parameter]]

no_sync()[source][source]

禁用 FSDP 实例之间的梯度同步。

在此上下文内,梯度将累积在模块变量中,稍后在退出上下文后的第一次前向-反向传播中进行同步。此方法应仅在根 FSDP 实例上使用,并将递归应用于所有子 FSDP 实例。

注意

这可能会导致更高的内存使用,因为 FSDP 将累积完整的模型梯度(而不是梯度分片),直到最终同步。

注意

与 CPU 卸载一起使用时,在上下文管理器内部时,梯度不会被卸载到 CPU。相反,它们只会在最终同步之后立即卸载。

返回类型

Generator

static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source][source]

转换与分片模型对应的优化器的状态字典。

给定的状态字典可以被转换为以下三种类型之一:1) 完整的优化器状态字典,2) 分片优化器状态字典,3) 本地优化器状态字典。

对于完整的优化器状态字典,所有状态都是未展平且未分片的。可以通过 state_dict_type() 指定仅在 Rank 0 上和仅在 CPU 上来避免内存溢出 (OOM)。

对于分片优化器状态字典,所有状态都是未展平但已分片的。可以通过 state_dict_type() 指定仅在 CPU 上来进一步节省内存。

对于本地状态字典,不会执行任何转换。但会将状态从 nn.Tensor 转换为 ShardedTensor 以表示其分片性质(目前尚不支持此功能)。

示例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
参数
  • model (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例),其参数被传递给了优化器 optim

  • optim (torch.optim.Optimizer) – 针对 model 参数的优化器。

  • optim_state_dict (Dict[str, Any]) – 要转换的目标优化器状态字典。如果该值为 None,将使用 optim.state_dict()。(默认值: None)

  • group (dist.ProcessGroup) – 参数分片所在的模型的进程组,如果使用默认进程组则为 None。(默认值: None)

返回

一个 dict,包含 model 的优化器状态。优化器状态的分片基于 state_dict_type

返回类型

Dict[str, Any]

static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source][source]

转换优化器状态字典,使其可以加载到与 FSDP 模型关联的优化器中。

给定一个通过 optim_state_dict() 转换而来的 optim_state_dict,此方法会将其转换为展平的优化器状态字典,以便加载到 `optim` 中,`optim` 是 `model` 的优化器。model 必须由 FullyShardedDataParallel 进行分片。

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model,
>>>     optim,
>>>     optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)
参数
  • model (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例),其参数被传递给了优化器 optim

  • optim (torch.optim.Optimizer) – 针对 model 参数的优化器。

  • optim_state_dict (Dict[str, Any]) – 待加载的优化器状态。

  • is_named_optimizer (bool) – 此优化器是否为 NamedOptimizer 或 KeyedOptimizer。仅当 optim 是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 时,才设置为 True。

  • load_directly (bool) – 如果设置为 True,此 API 在返回结果之前也会调用 optim.load_state_dict(result)。否则,用户需要负责调用 optim.load_state_dict()。(默认值: False)

  • group (dist.ProcessGroup) – 参数分片所在的模型的进程组,如果使用默认进程组则为 None。(默认值: None)

返回类型

dict[str, Any]

register_comm_hook(state, hook)[source][source]

注册一个通信钩子。

这是一项增强功能,为用户提供了一个灵活的钩子,用户可以在其中指定 FSDP 如何聚合多个 worker 上的梯度。此钩子可用于实现诸如 GossipGrad 和梯度压缩等多种算法,这些算法在使用 FullyShardedDataParallel 训练时涉及不同的参数同步通信策略。

警告

FSDP 通信钩子应在运行初始前向传播之前注册,并且只能注册一次。

参数
  • state (object) –

    传递给钩子,用于在训练过程中维护任何状态信息。示例包括梯度压缩中的错误反馈、GossipGrad 中下一个要通信的对等方等。它由每个 worker 本地存储,并由 worker 上的所有梯度张量共享。

  • hook (Callable) – 可调用对象,具有以下签名之一:1) hook: Callable[torch.Tensor] -> None:此函数接受一个 Python 张量,该张量表示与此 FSDP 单元包装的模型(未被其他 FSDP 子单元包装的变量)对应的所有变量的完整的、展平的、未分片的梯度。然后执行所有必要的处理并返回 None;2) hook: Callable[torch.Tensor, torch.Tensor] -> None:此函数接受两个 Python 张量,第一个张量表示与此 FSDP 单元包装的模型(未被其他 FSDP 子单元包装的变量)对应的所有变量的完整的、展平的、未分片的梯度。后者表示一个预设大小的张量,用于存储归约后的分片梯度的一个块。在这两种情况下,可调用对象都会执行所有必要的处理并返回 None。签名 1 的可调用对象预期处理 NO_SHARD 情况下的梯度通信。签名 2 的可调用对象预期处理分片情况下的梯度通信。

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source][source]

重设优化器状态字典 optim_state_dict 的键,以使用键类型 optim_state_key_type

这可用于实现在具有 FSDP 实例的模型和没有 FSDP 实例的模型之间优化器状态字典的兼容性。

将 FSDP 完整优化器状态字典(即来自 full_optim_state_dict())的键重设为使用参数 ID,以便可以加载到未包装的模型

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

将来自未包装的模型的普通优化器状态字典的键重设,以便可以加载到包装的模型

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
返回

使用由 optim_state_key_type 指定的参数键重设键的优化器状态字典。

返回类型

Dict[str, Any]

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source][source]

将完整的优化器状态字典从 rank 0 散射到所有其他 rank。

返回每个 rank 上的分片优化器状态字典。返回值与 shard_full_optim_state_dict() 相同,并且在 rank 0 上,第一个参数应为 full_optim_state_dict() 的返回值。

示例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

注意

shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用于获取要加载的分片优化器状态字典。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个 rank 在 CPU 内存中拥有完整的字典,并且每个 rank 单独分片字典而无需任何通信;后者仅要求 rank 0 在 CPU 内存中拥有完整的字典,并且 rank 0 将每个分片移动到 GPU 内存(对于 NCCL)并适当地将其通信到各个 rank。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。

参数
  • full_optim_state_dict (Optional[Dict[str, Any]]) – 如果在 rank 0 上,则为对应于未展平参数并包含完整的非分片优化器状态的优化器状态字典;此参数在非零 rank 上会被忽略。

  • model (torch.nn.Module) – 根模块(可能是一个 FullyShardedDataParallel 实例,也可能不是),其参数对应于 full_optim_state_dict 中的优化器状态。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组的 list 或参数的可迭代对象;如果为 None,则此方法假定输入是 model.parameters()。此参数已弃用,不再需要传入。(默认值:None

  • optim (Optional[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比 optim_input 更推荐使用的参数。(默认值:None

  • group (dist.ProcessGroup) – 参数分片所在的模型的进程组,如果使用默认进程组则为 None。(默认值: None)

返回

完整的优化器状态字典现在已重新映射到展平参数而不是未展平参数,并且仅包含此 rank 的优化器状态部分。

返回类型

Dict[str, Any]

static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source][source]

设置目标模块所有后代 FSDP 模块的 state_dict_type

还接受(可选)的模型和优化器状态字典配置。目标模块不必是 FSDP 模块。如果目标模块是 FSDP 模块,其 state_dict_type 也会被更改。

注意

此 API 仅应在顶级(根)模块上调用。

注意

在根 FSDP 模块被另一个 nn.Module 包装的情况下,此 API 使用户能够透明地使用传统的 state_dict API 来进行模型检查点。例如,以下代码将确保在所有非 FSDP 实例上调用 state_dict,同时对于 FSDP 实例分派到 sharded_state_dict 实现:

示例

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
参数
  • module (torch.nn.Module) – 根模块。

  • state_dict_type (StateDictType) – 期望设置的 state_dict_type

  • state_dict_config (Optional[StateDictConfig]) – 目标 state_dict_type 的配置。

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 优化器状态字典的配置。

返回

一个 StateDictSettings,包含模块之前的 state_dict 类型和配置。

返回类型

StateDictSettings

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source][source]

对完整的优化器状态字典进行分片。

full_optim_state_dict 中的状态重新映射到展平参数,而不是未展平参数,并且仅限于此 rank 的优化器状态部分。第一个参数应该是 full_optim_state_dict() 的返回值。

示例

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

注意

shard_full_optim_state_dict()scatter_full_optim_state_dict() 都可以用于获取要加载的分片优化器状态字典。假设完整的优化器状态字典位于 CPU 内存中,前者要求每个 rank 在 CPU 内存中拥有完整的字典,并且每个 rank 单独分片字典而无需任何通信;后者仅要求 rank 0 在 CPU 内存中拥有完整的字典,并且 rank 0 将每个分片移动到 GPU 内存(对于 NCCL)并适当地将其通信到各个 rank。因此,前者具有更高的总 CPU 内存成本,而后者具有更高的通信成本。

参数
  • full_optim_state_dict (Dict[str, Any]) – 对应于未展平参数并包含完整的非分片优化器状态的优化器状态字典。

  • model (torch.nn.Module) – 根模块(可能是一个 FullyShardedDataParallel 实例,也可能不是),其参数对应于 full_optim_state_dict 中的优化器状态。

  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组的 list 或参数的可迭代对象;如果为 None,则此方法假定输入是 model.parameters()。此参数已弃用,不再需要传入。(默认值:None

  • optim (Optional[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比 optim_input 更推荐使用的参数。(默认值:None

返回

完整的优化器状态字典现在已重新映射到展平参数而不是未展平参数,并且仅包含此 rank 的优化器状态部分。

返回类型

Dict[str, Any]

static sharded_optim_state_dict(model, optim, group=None)[source][source]

返回分片形式的优化器状态字典。

此 API 类似于 full_optim_state_dict(),但此 API 将所有非零维度状态分块到 ShardedTensor 以节省内存。此 API 仅应在模型 state_dict 使用上下文管理器 with state_dict_type(SHARDED_STATE_DICT): 获取时使用。

有关详细用法,请参考 full_optim_state_dict()

警告

返回的状态字典包含 ShardedTensor,不能直接用于常规的 optim.load_state_dict

返回类型

dict[str, Any]

static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source][source]

设置目标模块所有后代 FSDP 模块的 state_dict_type

此上下文管理器与 set_state_dict_type() 具有相同的功能。详细信息请阅读 set_state_dict_type() 的文档。

示例

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>> ):
>>>     checkpoint = model.state_dict()
参数
  • module (torch.nn.Module) – 根模块。

  • state_dict_type (StateDictType) – 期望设置的 state_dict_type

  • state_dict_config (Optional[StateDictConfig]) – 目标 state_dict_type 的模型 state_dict 配置。

  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 目标 state_dict_type 的优化器 state_dict 配置。

返回类型

Generator

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source][source]

使用此上下文管理器暴露 FSDP 实例的完整参数。

在模型的前向/反向传播之后可能很有用,以获取参数进行额外的处理或检查。它可以接受非 FSDP 模块,并且将根据 recurse 参数,为其包含的所有 FSDP 模块及其子模块召唤完整参数。

注意

这可以在内部 FSDP 上使用。

注意

不能在前向或反向传播过程中使用。也不能在此上下文管理器内部启动前向或反向传播。

注意

上下文管理器退出后,参数将恢复其本地分片,存储行为与前向传播相同。

注意

可以修改完整参数,但上下文管理器退出后,只有对应于本地参数分片的部分会持久化(除非 writeback=False,在这种情况下更改将被丢弃)。在 FSDP 不对参数进行分片的情况下,目前仅在 world_size == 1NO_SHARD 配置时,无论 writeback 如何,修改都会持久化。

注意

此方法适用于本身不是 FSDP 但可能包含多个独立 FSDP 单元的模块。在这种情况下,给定的参数将应用于所有包含的 FSDP 单元。

警告

请注意,rank0_only=Truewriteback=True 结合使用目前不受支持,并将引发错误。这是因为在上下文中,模型参数形状在不同 rank 之间会不同,对其进行写入可能导致上下文退出时不同 rank 之间的数据不一致。

警告

请注意,offload_to_cpurank0_only=False 将导致完整参数被冗余地复制到同一机器上 GPU 所驻留的 CPU 内存中,这可能导致 CPU OOM 的风险。建议将 offload_to_cpurank0_only=True 一起使用。

参数
  • recurse (bool, Optional) – 递归地召唤嵌套 FSDP 实例的所有参数(默认值:True)。

  • writeback (bool, Optional) – 如果为 False,则参数修改在上下文管理器退出后将被丢弃;禁用此选项可以稍微提高效率(默认值:True)。

  • rank0_only (bool, Optional) – 如果为 True,则仅在全局 rank 0 上实例化完整参数。这意味着在上下文中,只有 rank 0 具有完整参数,而其他 rank 具有分片参数。请注意,将 rank0_only=Truewriteback=True 一起设置不受支持,因为在上下文中,模型参数形状在不同 rank 之间会不同,对其进行写入可能导致上下文退出时不同 rank 之间的数据不一致。

  • offload_to_cpu (bool, Optional) – 如果为 True,则将完整参数卸载到 CPU。请注意,这种卸载目前仅在参数被分片时才会发生(只有在 world_size = 1 或 NO_SHARD 配置时才不会发生)。建议将 offload_to_cpurank0_only=True 一起使用,以避免将模型参数的冗余副本卸载到同一 CPU 内存中。

  • with_grads (bool, Optional) – 如果为 True,则也随参数一起解除梯度的分片。目前,仅在使用 use_orig_params=True 传递给 FSDP 构造函数并将 offload_to_cpu=False 传递给此方法时支持此功能。(默认值:False

返回类型

Generator

class torch.distributed.fsdp.BackwardPrefetch(value)[source][source]

此配置显式反向预取,通过在反向传播中启用通信和计算重叠来提高吞吐量,代价是略微增加内存使用量。

  • BACKWARD_PRE:这实现了最多的重叠,但内存使用量增加得也最多。它在当前参数集的梯度计算之前预取下一组参数。这重叠了下一次 all-gather当前梯度计算,并且在峰值时,它在内存中保存当前参数集、下一组参数和当前梯度集。

  • BACKWARD_POST:这实现了较少的重叠,但需要较少的内存使用量。它在当前参数集的梯度计算之后预取下一组参数。这重叠了当前 reduce-scatter下一次梯度计算,并且它在为下一组参数分配内存之前释放当前参数集,在峰值时仅在内存中保存下一组参数和当前梯度集。

  • FSDP 的 backward_prefetch 参数接受 None,它完全禁用反向预取。这没有重叠,也不会增加内存使用量。通常,我们不推荐此设置,因为它可能会显著降低吞吐量。

更多技术背景:对于使用 NCCL 后端的单个进程组,任何集合操作,即使来自不同的流,也会争夺相同的设备级 NCCL 流,这意味着集合操作发出的相对顺序对重叠很重要。这两个反向预取值对应于不同的发出顺序。

class torch.distributed.fsdp.ShardingStrategy(value)[source][source]

这指定了 FullyShardedDataParallel 用于分布式训练的分片策略。

  • FULL_SHARD:参数、梯度和优化器状态都进行分片。对于参数,此策略在前向传播之前取消分片(通过 all-gather),在前向传播之后重新分片,在反向计算之前取消分片,并在反向计算之后重新分片。对于梯度,它在反向计算之后同步并分片(通过 reduce-scatter)。分片优化器状态在每个 rank 上本地更新。

  • SHARD_GRAD_OP:梯度和优化器状态在计算过程中分片,此外,参数在计算之外分片。对于参数,此策略在前向传播之前取消分片,在前向传播之后不重新分片,仅在反向计算之后重新分片。分片优化器状态在每个 rank 上本地更新。在 no_sync() 内部,参数在反向计算之后不重新分片。

  • NO_SHARD:参数、梯度和优化器状态不进行分片,而是像 PyTorch 的 DistributedDataParallel API 一样在 rank 之间复制。对于梯度,此策略在反向计算之后同步(通过 all-reduce)。未分片优化器状态在每个 rank 上本地更新。

  • HYBRID_SHARD:在节点内部应用 FULL_SHARD,并在节点之间复制参数。这减少了通信量,因为昂贵的 all-gather 和 reduce-scatter 只在节点内部完成,对于中等大小的模型可能更高效。

  • _HYBRID_SHARD_ZERO2:在节点内部应用 SHARD_GRAD_OP,并在节点之间复制参数。这类似于 HYBRID_SHARD,但可能提供更高的吞吐量,因为未分片参数在前向传播后不会被释放,从而节省了反向传播前的 all-gather。

class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source][source]

这配置了 FSDP 原生的混合精度训练。

变量
  • param_dtype (Optional[torch.dtype]) – 这指定了在前向和反向传播过程中模型参数的数据类型(dtype),从而也指定了前向和反向计算的数据类型。在前向和反向传播之外,分片参数以全精度保存(例如,用于优化器步骤),并且对于模型检查点,参数始终以全精度保存。(默认值:None

  • reduce_dtype (Optional[torch.dtype]) – 这指定了梯度归约(即 reduce-scatter 或 all-reduce)的数据类型。如果此参数为 Noneparam_dtype 不为 None,则此参数将采用 param_dtype 的值,仍然在低精度下进行梯度归约。此参数允许与 param_dtype 不同,例如强制梯度归约以全精度运行。(默认值:None

  • buffer_dtype (Optional[torch.dtype]) – 这指定了缓冲区的数据类型。FSDP 不分片缓冲区。相反,FSDP 在第一次前向传播中将它们转换为 buffer_dtype 并此后保持该数据类型。对于模型检查点,缓冲区以全精度保存,除了 LOCAL_STATE_DICT。(默认值:None

  • keep_low_precision_grads (bool) – 如果为 False,FSDP 在反向传播后将梯度转换为全精度,为优化器步骤做准备。如果为 True,FSDP 将梯度保持在用于梯度归约的数据类型中,如果使用支持在低精度下运行的自定义优化器,则可以节省内存。(默认值:False

  • cast_forward_inputs (bool) – 如果为 True,则此 FSDP 模块将其前向传播的 args 和 kwargs 转换为 param_dtype。这是为了确保参数和输入数据类型在前向计算中匹配,这是许多操作所要求的。当仅对部分 FSDP 模块应用混合精度时,可能需要将其设置为 True,在这种情况下,混合精度 FSDP 子模块需要重新转换其输入。(默认值:False

  • cast_root_forward_inputs (bool) – 如果为 True,则根 FSDP 模块将其前向传播的 args 和 kwargs 转换为 param_dtype,从而覆盖 cast_forward_inputs 的值。对于非根 FSDP 模块,此参数不起作用。(默认值:True

  • _module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 这指定了在使用 auto_wrap_policy 时要忽略混合精度的模块类:这些类的模块将单独应用 FSDP,并且混合精度被禁用(这意味着最终的 FSDP 构建将偏离指定的策略)。如果未指定 auto_wrap_policy,则此参数不起作用。此 API 是实验性的,可能会发生变化。(默认值:(_BatchNorm,)

注意

此 API 是实验性的,可能会发生变化。

注意

只有浮点张量会被转换为指定的数据类型。

注意

summon_full_params 中,参数被强制转换为全精度,但缓冲区不会。

注意

Layer norm 和 batch norm 即使输入是低精度(如 float16bfloat16)也会在 float32 中累积。对这些 norm 模块禁用 FSDP 的混合精度仅意味着仿射参数保留在 float32 中。然而,这会导致这些 norm 模块产生单独的 all-gather 和 reduce-scatter,这可能效率低下,因此如果工作负载允许,用户应优先仍然对这些模块应用混合精度。

注意

默认情况下,如果用户传递的模型包含任何 _BatchNorm 模块并指定了 auto_wrap_policy,则 batch norm 模块将单独应用 FSDP,并且混合精度被禁用。请参阅 _module_classes_to_ignore 参数。

注意

MixedPrecision 默认设置 cast_root_forward_inputs=Truecast_forward_inputs=False。对于根 FSDP 实例,其 cast_root_forward_inputs 优先于其 cast_forward_inputs。对于非根 FSDP 实例,其 cast_root_forward_inputs 值被忽略。默认设置对于典型情况已经足够,即每个 FSDP 实例具有相同的 MixedPrecision 配置,并且只需要在模型前向传播的开始时将输入转换为 param_dtype

注意

对于具有不同 MixedPrecision 配置的嵌套 FSDP 实例,我们建议设置单独的 cast_forward_inputs 值来配置是否在每个实例的前向传播之前转换输入。在这种情况下,由于转换发生在每个 FSDP 实例的前向传播之前,父 FSDP 实例应在其 FSDP 子模块之前运行其非 FSDP 子模块,以避免激活的数据类型因不同的 MixedPrecision 配置而改变。

示例

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1],
>>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
>>> )
>>> model = FSDP(
>>>     model,
>>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
>>> )

上面展示了一个工作示例。另一方面,如果 model[1]model[0] 替换,意味着使用不同 MixedPrecision 的子模块首先运行其前向传播,那么 model[1] 将错误地看到 float16 激活而不是 bfloat16 激活。

torch.distributed.fsdp.CPUOffload(offload_params=False)[source][source]

这用于配置 CPU 卸载。

变量

offload_params (布尔值) – 这指定了在不参与计算时是否将参数卸载到 CPU。如果 True,则也会将梯度卸载到 CPU,这意味着优化器步骤在 CPU 上运行。

torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source][source]

StateDictConfig 是所有 state_dict 配置类的基类。用户应实例化一个子类(例如 FullStateDictConfig)来配置 FSDP 支持的相应 state_dict 类型的设置。

变量

offload_to_cpu (布尔值) – 如果为 True,则 FSDP 将 state dict 的值卸载到 CPU;如果为 False,则 FSDP 将它们保留在 GPU 上。(默认值:False)

torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source][source]

FullStateDictConfig 是用于 StateDictType.FULL_STATE_DICT 的配置类。我们建议在保存完整 state dict 时同时启用 offload_to_cpu=Truerank0_only=True,以分别节省 GPU 内存和 CPU 内存。此类配置旨在通过 state_dict_type() 上下文管理器使用,如下所示:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
>>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(
...     model,
...     device_id=torch.cuda.current_device(),
...     auto_wrap_policy=...,
...     sync_module_states=True,
... )
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
变量

rank0_only (布尔值) – 如果为 True,则只有 rank 0 保存完整的 state dict,而非零 rank 保存一个空字典。如果为 False,则所有 rank 都保存完整的 state dict。(默认值:False)

torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source][source]

ShardedStateDictConfig 是用于 StateDictType.SHARDED_STATE_DICT 的配置类。

变量

_use_dtensor (布尔值) – 如果为 True,则 FSDP 将 state dict 的值保存为 DTensor;如果为 False,则 FSDP 将它们保存为 ShardedTensor。(默认值:False)

警告

_use_dtensorShardedStateDictConfig 的一个私有字段,FSDP 用它来确定 state dict 值的类型。用户不应手动修改 _use_dtensor

torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)[source][source]
torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source][source]

OptimStateDictConfig 是所有 optim_state_dict 配置类的基类。用户应实例化一个子类(例如 FullOptimStateDictConfig)来配置 FSDP 支持的相应 optim_state_dict 类型的设置。

变量

offload_to_cpu (布尔值) – 如果为 True,则 FSDP 将 state dict 的 tensor 值卸载到 CPU;如果为 False,则 FSDP 将它们保留在原始设备上(如果未启用参数 CPU 卸载,则为 GPU)。(默认值:True)

torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[source][source]
变量

rank0_only (布尔值) – 如果为 True,则只有 rank 0 保存完整的 state dict,而非零 rank 保存一个空字典。如果为 False,则所有 rank 都保存完整的 state dict。(默认值:False)

torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[source][source]

ShardedOptimStateDictConfig 是用于 StateDictType.SHARDED_STATE_DICT 的配置类。

变量

_use_dtensor (布尔值) – 如果为 True,则 FSDP 将 state dict 的值保存为 DTensor;如果为 False,则 FSDP 将它们保存为 ShardedTensor。(默认值:False)

警告

_use_dtensorShardedOptimStateDictConfig 的一个私有字段,FSDP 用它来确定 state dict 值的类型。用户不应手动修改 _use_dtensor

torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)[source][source]
torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source][source]

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源