• 文档 >
  • 分布式检查点 - torch.distributed.checkpoint
快捷方式

分布式检查点 - torch.distributed.checkpoint

分布式检查点 (DCP) 支持从多个 rank 并行加载和保存模型。它处理加载时重新分片,从而支持在一个集群拓扑中保存并在另一个集群拓扑中加载。

DCP 与 torch.savetorch.load 在几个重要方面有所不同

  • 它为每个检查点生成多个文件,每个 rank 至少一个文件。

  • 它就地操作,意味着模型应首先分配其数据,而 DCP 使用该存储。

加载和保存检查点的入口点如下

其他资源:

torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source][source]

以 SPMD 风格保存分布式模型。

此函数与 torch.save() 不同,因为它通过让每个 rank 仅保存其本地分片来处理 ShardedTensorDTensor

对于每个 Stateful 对象(同时具有 state_dictload_state_dict),save 将在序列化之前调用 state_dict

警告

不保证跨 PyTorch 版本的已保存 state_dict 的向后兼容性。

警告

如果使用 process_group 参数,请确保只有其 rank 调用 save_state_dict,并且 state_dict 中的所有数据都属于它。

注意

当为 FSDP 的 ShardingStrategy.HYBRID_SHARD 保存检查点时,只有 shard_group 中的一个应该调用 save_state_dict,并且需要传入相应的进程组。

注意

如果没有可用的进程组,此函数假定意图是保存在

本地进程中的 state_dict。

参数
  • state_dict (Dict[str, Any]) – 要保存的 state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。 checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,则它也可以是键。(默认值:None

  • storage_writer (Optional[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断 writer。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定,则将使用默认 planner。(默认值:None

  • process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认值:None

返回

已保存检查点的元数据对象。

返回类型

元数据

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )

注意

save_state_dict 使用 collectives 来协调跨 rank 的写入。对于基于 NCCL 的进程组,对象的内部张量表示必须先移动到 GPU 设备,然后才能进行通信。在这种情况下,使用的设备由 torch.cuda.current_device() 给出,用户有责任确保将其设置为使每个 rank 都有一个单独的 GPU,通过 torch.cuda.set_device()

torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source][source]

save 的异步版本。此代码首先将 state_dict 反序列化到暂存存储(默认为 CPU 内存),然后在单独的线程中调用 save

警告

此功能是实验性的,可能会发生变化。

参数
  • state_dict (Dict[str, Any]) – 要保存的 state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。 checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,则它也可以是键。(默认值:None

  • storage_writer (Optional[StorageWriter]) – 用于执行“stage”和“save”的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断 writer。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定,则将使用默认 planner。(默认值:None

  • process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认值:None

返回

一个 Future,其中包含来自 save 的结果 Metadata 对象。

返回类型

Future

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... do some work ...
>>>
>>> checkpoint_future.result()
torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]

此方法已弃用。请切换到“save”。

返回类型

元数据

torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None)[source][source]

以 SPMD 风格加载分布式 state_dict

每个 rank 将尝试读取满足所请求的 state_dict 所需的最少量数据。加载 ShardedTensorDTensor 实例时,每个 rank 仅读取其本地分片的数据。

对于每个 Stateful 对象(同时具有 state_dictload_state_dict),load 将首先调用 state_dict,然后再尝试反序列化,然后在反序列化完成后调用 load_state_dict。对于每个非 Stateful 对象,load 将反序列化该对象,然后在 state_dict 中将其替换为反序列化的对象。

警告

state_dict 中的所有张量都必须在调用此函数之前分配到其目标设备上。

所有非张量数据都使用 torch.load() 加载,并在 state_dict 上就地修改。

警告

用户必须在根模块上调用 load_state_dict,以确保加载后处理和非张量数据正确传播。

参数
  • state_dict (Dict[str, Any]) – 要保存的 state_dict。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。 checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,则它也可以是键。(默认值:None

  • storage_reader (Optional[StorageReader]) – 用于执行读取的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断 reader。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • planner (Optional[LoadPlanner]) – LoadPlanner 实例。如果未指定,则将使用默认 planner。(默认值:None

  • process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认值:None

返回

无。

返回类型

示例
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to
>>> # ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意

load_state_dict 使用 collectives 来协调跨 rank 的读取。对于基于 NCCL 的进程组,对象的内部张量表示必须先移动到 GPU 设备,然后才能进行通信。在这种情况下,使用的设备由 torch.cuda.current_device() 给出,用户有责任确保将其设置为使每个 rank 都有一个单独的 GPU,通过 torch.cuda.set_device()

torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]

此方法已弃用。请切换到“load”。

以下模块对于异步检查点机制(torch.distributed.checkpoint.async_save)的额外自定义也很有用

class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)[source][source]

此协议旨在为 dcp.async_save 提供自定义和可扩展性,允许用户自定义如何在并行执行通常的 dcp.save 路径之前暂存数据。预期的操作顺序(在 torch.distributed.state_dict_saver.async_save 中具体定义)如下

  1. AsyncStager.stage_data(state_dict)

    此调用使 AsyncStager 有机会“暂存” state_dict。此上下文中暂存的期望和目的是创建 state dict 的“训练安全”表示,这意味着在暂存完成后对模块数据的任何更新都不应反映在此方法返回的 state dict 中。例如,在默认情况下,会在 CPU RAM 上创建整个 state dict 的副本并在此处返回,从而允许用户继续训练,而不会冒数据被序列化的风险。

  2. dcp.save 并行调用从 stage 返回的 state_dict。此调用负责

    序列化 state_dict 并将其写入存储。

  3. 如果 AsyncStager.should_synchronize_after_execute 为 True,则此方法将在序列化线程启动后立即调用,并在从 dcp.async_save 返回之前调用。如果设置为 False,则假设用户已为进一步优化训练循环中的保存延迟定义了自定义同步点(例如,通过将暂存与前向/后向传递重叠),并且用户有责任在适当的时间调用 AsyncStager.synchronize_staging

    the serialization thread starts and before returning from dcp.async_save. If this is set to False, the assumption is the user has defined a custom synchronization point for the the purpose of further optimizing save latency in the training loop (for example, by overlapping staging with the forward/backward pass), and it is the respondsibility of the user to call AsyncStager.synchronize_staging at the appropriate time.

property should_synchronize_after_execute: bool

是否在执行暂存后同步。

stage(state_dict)[source][source]

返回 state_dict 的“暂存”副本。暂存副本的期望是,它不会受到暂存调用完成后发生的任何更新的影响。

返回类型

Dict[str, Union[StatefulT, Any]]

synchronize_staging()[source][source]

如果 stage 在某种程度上是异步的,则应调用此方法以确保暂存完成并且可以安全地开始修改原始 state_dict

class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)[source][source]

AsyncStager 的一个实现,它在 CPU RAM 上暂存 state_dict 并阻塞直到复制完成。此实现还提供了一个选项,可以使用固定内存来优化暂存延迟。

注意:synchronize_staging 在这种情况下是空操作。

stage(state_dict)[source][source]

返回 CPU 上 state_dict 的副本。

返回类型

Dict[str, Union[StatefulT, Any]]

synchronize_staging()[source][source]

空操作函数,因为暂存是阻塞的。

除了上述入口点之外,如下所述的 Stateful 对象在保存/加载期间提供额外的自定义。 automodule:: torch.distributed.checkpoint.stateful

class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source][source]

用于可以检查点和恢复的对象的有状态协议。

load_state_dict(state_dict)[source][source]

从提供的 state_dict 恢复对象的状态。

参数

state_dict (Dict[str, Any]) – 要从中恢复的 state dict

state_dict()[source][source]

对象应将其 state_dict 表示形式作为字典返回。此函数的输出将被检查点,并在稍后的 load_state_dict() 中恢复。

警告

由于恢复检查点的就地性质,此函数也会在 torch.distributed.checkpoint.load 期间调用。

返回

对象的 state dict

返回类型

Dict

示例 展示了如何使用 Pytorch 分布式检查点来保存 FSDP 模型。

以下类型定义了检查点期间使用的 IO 接口

class torch.distributed.checkpoint.StorageReader[source][source]

load_state_dict 使用的从存储读取数据的接口。

一个 StorageReader 实例充当分布式检查点中的协调器和跟随者。作为初始化的一部分,每个实例都会被告知其角色。

子类应预期 load_state_dict 按以下顺序调用:

  1. (所有 rank)如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。

  2. (所有 rank)read_metadata()

  3. (所有 rank)set_up_storage_reader()

  4. (所有 rank)prepare_local_plan()

  5. (协调器)prepare_global_plan()

  6. (所有 rank)read_data()

abstract prepare_global_plan(plans)[source][source]

执行存储加载的集中式规划。

此方法仅在协调器实例上调用。

虽然此方法可以生成完全不同的计划,但首选方法是将存储特定的数据存储在 LoadPlan::storage_data 中。

参数

plans (List[LoadPlan]) – LoadPlan 实例的列表,每个 rank 一个。

返回

存储全局规划后转换的 LoadPlan 列表

返回类型

List[LoadPlan]

abstract prepare_local_plan(plan)[source][source]

执行特定于存储的本地规划。

虽然此方法可以生成完全不同的计划,但建议的方法是将存储特定的数据存储在 LoadPlan::storage_data 中。

参数

plan (LoadPlan) – 正在使用的 LoadPlan 中的本地计划。

返回

存储本地规划后转换的 LoadPlan

返回类型

LoadPlan

abstract read_data(plan, planner)[source][source]

使用 planner 解析数据,从 plan 中读取所有项。

子类应调用 LoadPlanner::load_bytes 以将 BytesIO 对象反序列化到正确的位置。

子类应调用 LoadPlanner::resolve_tensor 以访问应加载数据到的 tensors。

StorageLayer 负责正确调度所需的任何跨设备复制。

参数
  • plan (LoadPlan) – 要在其上执行的本地计划

  • planner (LoadPlanner) – 用于解析项的 planner 对象。

返回

一旦所有读取完成,将完成的 Future。

返回类型

Future[None]

abstract read_metadata()[source][source]

读取检查点元数据。

返回

与正在加载的检查点关联的元数据对象。

返回类型

元数据

abstract reset(checkpoint_id=None)[source][source]

调用以指示即将发生全新的检查点读取。如果用户为此检查点读取设置了 checkpoint_id,则可能会存在 checkpoint_id。checkpiont_id 的含义取决于存储。它可以是文件夹/文件的路径,也可以是键值存储的键。

参数

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储更像是键值存储,它也可以是一个键。(默认值:None

abstract set_up_storage_reader(metadata, is_coordinator)[source][source]

初始化此实例。

参数
  • metadata (Metadata) – 要使用的元数据模式。

  • is_coordinator (bool) – 此实例是否负责协调检查点。

abstract classmethod validate_checkpoint_id(checkpoint_id)[source][source]

检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。

返回类型

bool

class torch.distributed.checkpoint.StorageWriter[source][source]

save_state_dict 使用的接口,用于写入存储。

一个 StorageWriter 实例充当分布式检查点中的协调器和跟随者。作为初始化的一部分,会告知每个实例其角色。

子类应预期以下调用顺序。

  1. (所有 rank)如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。

  2. (所有 rank)set_up_storage_writer()

  3. (所有 rank)prepare_local_plan()

  4. (协调器)prepare_global_plan()

  5. (所有 rank)write_data()

  6. (协调器)finish()

abstract finish(metadata, results)[source][source]

写入元数据并将当前检查点标记为成功。

用于序列化 metadata 的实际格式/模式是一个实现细节。唯一的要求是它可以在相同的对象图中恢复。

参数
  • metadata (Metadata) – 新检查点的元数据

  • results (List[List[WriteResult]]) – 来自所有 rank 的 WriteResults 列表。

返回

返回类型

abstract prepare_global_plan(plans)[source][source]

执行存储的集中式规划。

此方法仅在协调器实例上调用。

虽然此方法可以生成完全不同的计划,但首选方法是将存储特定的数据存储在 SavePlan::storage_data 中。

参数

plans (List[SavePlan]) – SavePlan 实例的列表,每个 rank 一个。

返回

存储全局规划后转换的 SavePlan 列表

返回类型

List[SavePlan]

abstract prepare_local_plan(plan)[source][source]

执行特定于存储的本地规划。

虽然此方法可以生成完全不同的计划,但建议的方法是将存储特定的数据存储在 SavePlan::storage_data 中。

参数

plan (SavePlan) – 正在使用的 SavePlanner 中的本地计划。

返回

存储本地规划后转换的 SavePlan

返回类型

SavePlan

abstract reset(checkpoint_id=None)[source][source]

调用以指示即将发生全新的检查点写入。如果用户为此检查点写入设置了 checkpoint_id,则可能会存在 checkpoint_id。checkpiont_id 的含义取决于存储。它可以是文件夹/文件的路径,也可以是键值存储的键。

参数

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。 checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储是键值存储,则它也可以是键。(默认值:None

abstract set_up_storage_writer(is_coordinator)[source][source]

初始化此实例。

参数

is_coordinator (bool) – 此实例是否负责协调检查点。

storage_meta()[source][source]

返回特定于存储的元数据。这用于在检查点中存储附加信息,这些信息对于提供请求级别的可观察性很有用。StorageMeta 在保存调用期间传递给 SavePlanner。默认返回 None。

TODO:提供一个示例

返回类型

Optional[StorageMeta]

abstract classmethod validate_checkpoint_id(checkpoint_id)[source][source]

检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。

返回类型

bool

abstract write_data(plan, planner)[source][source]

使用 planner 解析数据,从 plan 中写入所有项。

子类应在计划中的每个项上调用 SavePlanner::resolve_data,以访问要写入的底层对象。

子类应延迟调用 resolve_data,因为它可能会分配内存。对于 tensors,请做出以下假设

  • 它们可能在任何设备上,包括与 WriteItem::tensor_data 上的设备不匹配的设备

  • 它们可能是视图或不连续的。只需要保存投影。

参数
  • plan (SavePlan) – 要执行的保存计划。

  • planner (SavePlanner) – Planner 对象,用于将项解析为数据。

返回

完成时会生成 WriteResult 列表的 Future

返回类型

Future[List[WriteResult]]

以下类型定义了检查点期间使用的 planner 接口

class torch.distributed.checkpoint.LoadPlanner[source][source]

定义 load_state_dict 用于规划加载过程的协议的抽象类。

LoadPlanner 是有状态对象,可用于自定义整个加载过程。

LoadPlanner 充当 state_dict 的访问代理,因此对其进行的任何转换都将对整个过程可见。

planner 子类可以预期在 load_state_dict 期间进行以下调用顺序

  1. set_up_planner - 在所有 rank 上调用。

    表示检查点加载开始的信号。

  2. create_local_plan - 在所有 rank 上调用。

    处理 state_dict 并生成将发送以进行全局规划的 LoadPlan

  3. create_global_plan - 仅在协调器 rank 上调用。

    从所有 rank 获取 LoadPlan 并做出任何全局决策。

  4. load_bytes - 在每个 rank 上多次调用

    这在 state_dict 中每个非 tensor 值调用一次。

  5. resolve_tensor 和 commit_tensor - 在每个 rank 上多次调用

    它们对于 state_dict 中的每个 Tensor 值成对调用。

建议用户扩展 DefaultLoadPlanner 而不是直接扩展此接口,因为大多数更改可以通过单个方法中的更改来表达。

通常有两种扩展模式

重写 state_dict。这是扩展加载过程的最简单方法,因为它不需要理解 LoadPlan 如何工作的复杂性。我们需要保留对原始 state_dict 的引用,因为加载是就地发生的,因此我们需要能够就地执行它

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         metadata: Metadata,
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

修改 resolve_tensor 和 commit_tensor 以处理加载时转换。

>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
abstract commit_tensor(read_item, tensor)[source][source]

在 StorageReader 完成将数据加载到 tensor 后调用一次。

提供的 tensor 与调用 resolve_tensor 返回的 tensor 相同。仅当此 LoadPlanner 需要在将 tensor 复制回 state_dict 中的 tensor 之前对其进行后处理时,才需要此方法。

tensor 的内容将遵循其设备同步模型。

abstract create_global_plan(global_plan)[source][source]

计算全局加载计划并返回每个 rank 的计划。

. 注意:这仅在协调器 rank 上调用

返回类型

List[LoadPlan]

abstract create_local_plan()[source][source]

根据 state_dict 和 set_up_planner 提供的元数据创建 LoadPlan。

. 注意:这在每个 rank 上调用。

返回类型

LoadPlan

abstract finish_plan(central_plan)[source][source]

接受来自协调器的计划并返回最终的 LoadPlan。

返回类型

LoadPlan

abstract load_bytes(read_item, value)[source][source]

加载由 read_item``和 ``value 描述的项。

此方法应就地修改底层 state_dict。

value 的内容由用于生成正在加载的检查点的 SavePlanner 定义。

resolve_bytes(read_item)[source][source]

返回 StorageReader 用于加载 read_item 的 BytesIO。

BytesIO 应与底层 state_dict 上的 BytesIO 别名,因为 StorageReader 将替换其内容。

返回类型

BytesIO

abstract resolve_tensor(read_item)[source][source]

返回由 read_item 描述的 tensor,StorageReader 将使用它来加载 read_item

tensor 应与底层 state_dict 上的 tensor 别名,因为 StorageReader 将替换其内容。如果由于任何原因无法实现,planner 可以使用 commit_tensor 方法将数据复制回 state_dict 中的 tensor。

返回类型

Tensor

abstract set_up_planner(state_dict, metadata=None, is_coordinator=False)[source][source]

初始化此实例以将数据加载到 state_dict 中。

. 注意:这在每个 rank 上调用。

class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source][source]
class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source][source]
class torch.distributed.checkpoint.SavePlanner[source][source]

定义用于 save_state_dict 规划保存过程的抽象类。

SavePlanner 是有状态的对象,可用于自定义整个保存过程。

SavePlanner 充当 state_dict 的访问代理,因此对其进行的任何转换都将对整个过程可见。

planner 子类可以预期在 save_state_dict 期间进行以下调用序列

  1. set_up_planner - 在所有 rank 上调用。

    表示检查点保存的开始。

  2. create_local_plan - 在所有 rank 上调用。

    处理 state_dict 并生成将发送以进行全局规划的 SavePlan

  3. create_global_plan - 仅在协调器 rank 上调用。

    从所有 rank 获取 SavePlan 并做出任何全局决策。

  4. finish_plan - 在所有 rank 上调用。

    这使每个 rank 都有机会适应全局规划决策。

  5. resolve_data - 在每个 rank 上多次调用

    state_dict 上查找值以供存储层写入。

建议用户扩展 DefaultSavePlanner 而不是直接扩展此接口,因为大多数更改可以通过单个方法中的更改来表达。

通常有 3 种扩展模式

重写 state_dict。这是扩展保存过程的最简单方法,因为它不需要理解 SavePlan 如何工作的复杂性

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self,
>>>         state_dict: STATE_DICT_TYPE,
>>>         storage_meta: Optional[StorageMeta],
>>>         is_coordinator: bool,
>>>     ) -> None:
>>>         # prefix all keys with `foo_``
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

协同修改本地计划和查找。当需要精细控制数据持久化方式时,这很有用

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全局规划步骤来做出中央决策,这些决策无法由每个 rank 单独做出

>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>>     # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(*zip_longest(*iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)

最后,一些 planner 需要在检查点中保存额外的元数据,这可以通过让每个 rank 在本地计划中贡献其数据项,并由全局 planner 聚合它们来实现

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
abstract create_global_plan(all_plans)[source][source]

计算全局检查点计划并返回每个 rank 的本地计划。

这仅在协调器 rank 上调用。

返回类型

Tuple[List[SavePlan], Metadata]

abstract create_local_plan()[source][source]

计算当前 rank 的保存计划。

这将进行聚合并传递给 create_global_plan。Planner 特定数据可以通过 SavePlan::planner_data 传递。

这在所有 rank 上调用。

返回类型

SavePlan

abstract finish_plan(new_plan)[source][source]

合并由 create_local_plan 创建的计划和 create_global_plan 的结果。

这在所有 rank 上调用。

返回类型

SavePlan

abstract resolve_data(write_item)[source][source]

转换并准备来自 state_dictwrite_item 以进行存储,确保幂等性和线程安全性。

state_dict 中查找与 write_item 关联的对象,并在存储层使用它之前应用任何转换(例如序列化)。

在每个 rank 上多次调用,每个最终 SavePlan 中的 WriteItem 至少调用一次。

此方法应该是幂等且线程安全的。StorageWriter 实现可以根据需要频繁调用它。

任何分配内存的转换都应在此方法被调用时延迟完成,以减少检查点所需的峰值内存。

当返回张量时,它们可以在任何设备或格式上,它们也可以是视图。存储层负责弄清楚如何保存它们。

返回类型

Union[Tensor, BytesIO]

abstract set_up_planner(state_dict, storage_meta=None, is_coordinator=False)[source][source]

初始化此 planner 以保存 state_dict

实现应保存这些值,因为它们不会在保存过程中稍后提供。

这在所有 rank 上调用。

class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[source][source]
class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[source][source]

Dataclass,用于保存有关需要写入存储的信息。

tensor_storage_size()[source][source]

计算底层张量的存储大小,如果这不是张量写入,则返回 None。

返回

Optional[int] 存储大小,以字节为单位表示底层张量(如果有)。

返回类型

Optional[int]

我们提供了一个基于文件系统的存储层

class torch.distributed.checkpoint.FileSystemReader(path)[source][source]
property checkpoint_id: Union[str, PathLike]

返回将用于加载检查点的 checkpoint_id。

class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True)[source][source]

使用文件 IO 的 StorageWriter 的基本实现。

此实现做出以下假设和简化

  • 检查点路径是一个空目录或非现有目录。

  • 文件创建是原子性的

检查点由每个写入请求一个文件以及包含序列化元数据的 .metadata 文件组成。

stage(state_dict)[source][source]

AsyncStager.stage 的重写

返回类型

Dict[str, Union[StatefulT, Any]]

我们提供了 LoadPlannerSavePlanner 的默认实现,可以处理所有 torch.distributed 构造,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。

class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False)[source][source]
lookup_object(index)[source][source]

从 planner 接口的扩展,使其易于扩展默认 planner。

返回类型

任何

transform_object(write_item, object)[source][source]

从 planner 接口的扩展,使其易于扩展默认 planner。

class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]

DefaultLoadPlanner,在 LoadPlanner 的基础上添加了多项功能。

特别是,它添加了以下内容

flatten_state_dict:处理具有嵌套 dict 的 state_dict flatten_sharded_tensors:对于 2D 并行模式下的 FSDP allow_partial_load:如果为 False,则当 state_dict 中存在密钥但检查点中不存在时,将引发运行时错误。

lookup_tensor(index)[source][source]

从 planner 接口的扩展,使其易于扩展默认 planner。

返回类型

Tensor

transform_tensor(read_item, tensor)[source][source]

从 planner 接口的扩展,使其易于扩展默认 planner。

由于遗留的设计决策,即使原始的非并行化模型相同,FSDPDDP 的状态字典也可能具有不同的键或完全限定名称(例如,layer1.weight)。此外,FSDP 提供了各种类型的模型状态字典,例如完整状态字典和分片状态字典。此外,优化器状态字典使用参数 ID 而不是完全限定名称来标识参数,当使用并行性(例如,流水线并行)时,可能会导致问题。

为了应对这些挑战,我们提供了一系列 API,供用户轻松管理 state_dicts。get_model_state_dict 返回的模型状态字典的键与非并行化模型状态字典返回的键一致。同样,get_optimizer_state_dict 提供的优化器状态字典的键在所有应用的并行性中都是统一的。为了实现这种一致性,get_optimizer_state_dict 会将参数 ID 转换为与非并行化模型状态字典中找到的参数 ID 相同的完全限定名称。

请注意,这些 API 返回的结果可以直接与 torch.distributed.checkpoint.save()torch.distributed.checkpoint.load() 方法一起使用,而无需任何额外的转换。

请注意,此功能是实验性的,API 签名将来可能会更改。

torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]

返回模型 state_dict 和优化器 state_dict。

get_state_dict 可以处理通过 PyTorch FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及这些并行性的任意组合进行并行的任何模块。get_state_dict 的主要功能是:1.) 返回可以重新分片到不同数量的训练器和/或不同并行性的模型和优化器 state_dict。2.) 隐藏特定于并行性的 state_dict API。用户不必调用这些 API。3.) 健全性检查结果 state_dict。

结果状态字典的键是规范 FQN(完全限定名称)。规范 FQN 是指基于参数在 nn.Module 层次结构中的位置的 FQN。更具体地说,参数的规范 FQN 是指当模块未通过任何并行性分发时,module.named_parameters()module.named_buffers() 返回的 FQN。由于优化器在内部使用参数 ID 来表示参数,因此在调用此 API 时,将从参数 ID 转换为规范 FQN。

get_state_dict 也可以处理未并行的模块。在这种情况下,get_state_dict 仅执行一个功能 – 将优化器参数 ID 转换为规范 FQN。

示例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
>>> # the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict
参数
  • model (nn.Module) – 要作为模型的 nn.Module。

  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • submodules (deprecated) – Optional[Set[nn.Module]]: 仅返回属于子模块的模型参数。

  • options (StateDictOptions) – 用于控制应如何返回模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions

返回

Tuple,其中包含模型 state_dict 和优化器 state_dict。

返回类型

Tuple[Dict[str, ValueType], OptimizerStateType]

torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source][source]

返回 model 的模型 state_dict。

有关详细用法,请参阅 get_state_dict

参数
  • model (nn.Module) – 要作为模型的 nn.Module。

  • submodules (deprecated) – Optional[Set[nn.Module]]: 仅返回属于子模块的模型参数。

  • options (StateDictOptions) – 用于控制应如何返回模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions

返回

model 的 state_dict。

返回类型

Dict[str, ValueType]

torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]

返回优化器的组合 state_dict。

有关详细用法,请参阅 get_state_dict

参数
  • model (nn.Module) – 要作为模型的 nn.Module。

  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • submodules (deprecated) – Optional[Set[nn.Module]]: 仅返回属于子模块的模型参数。

  • options (StateDictOptions) – 用于控制应如何返回模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions

返回

optimizers 的 state_dict。

返回类型

OptimizerStateType

torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source][source]

加载模型 state_dict 和优化器 state_dict。

get_state_dict 的对应项,用于将 state_dict 设置到模型和优化器。给定的 model_state_dictoptim_state_dict 不必由 get_state_dict 返回,但必须满足以下要求:1) 所有 FQN 都是 get_state_dict 中定义的规范 FQN,2) 如果张量是分片的,则它必须是 ShardedTensor 或 DTensor,3) 优化器 state_dict 不能包含参数 ID;键应为规范 FQN。

参数
  • model (nn.Module) – 要作为模型的 nn.Module。

  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要加载的模型 state_dict。如果 model_state_dict 的键是 nn.Module,则该键是 model 的子模块,并且值应为子模块的 state_dict。加载 state_dict 时,子模块的前缀将附加到 state_dict。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器 state_dict。

  • options (StateDictOptions) – 用于控制应如何加载模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions

返回

  • missing_keys 是一个字符串列表,其中包含模型 state_dict 的缺失键。

  • unexpected_keys 是一个字符串列表,其中包含模型 state_dict 的意外键。

返回类型

NamedTuple,包含 missing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source][source]

加载模型 state_dict。

get_model_state_dict 的对应项,用于将 state_dict 设置到模型。有关详细用法,请参阅 set_state_dict

参数
  • model (nn.Module) – 要作为模型的 nn.Module。

  • model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 要加载的模型 state_dict。如果 model_state_dict 的键是 nn.Module,则该键是 model 的子模块,并且值应为子模块的 state_dict。加载 state_dict 时,子模块的前缀将附加到 state_dict。

  • options (StateDictOptions) – 用于控制应如何加载模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions

返回

  • missing_keys 是一个字符串列表,其中包含缺失键

  • unexpected_keys 是一个字符串列表,其中包含意外键

返回类型

NamedTuple,包含 missing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source][source]

加载优化器 state_dict。

get_optimizer_state_dict 的对应项,用于将 state_dict 设置到优化器。有关详细用法,请参阅 set_state_dict

参数
  • model (nn.Module) – 要作为模型的 nn.Module。

  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。

  • optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器 state_dict。

  • options (StateDictOptions) – 用于控制应如何加载模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions

返回

返回类型

class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False)[source][source]

此数据类指定了 get_state_dict/set_state_dict 的工作方式。

  • full_state_dict: 如果设置为 True,则返回的 state_dict 中的所有张量都将被收集。返回的 state_dict 中将不会有 ShardedTensor 和 DTensor。

  • cpu_offload: 将所有张量卸载到 cpu。为了防止 CPU OOM,如果 full_state_dict 也为 true,则只有 rank0 会获得 state_dict,而所有其他 rank 将获得空的 state_dict。

  • ignore_frozen_params: 如果值为 True,则返回的 state_dict 将不包含任何冻结的参数 – 即 requires_grad 为 False 的参数。默认值为 False。

  • keep_submodule_prefixes (已弃用): 当 submodules 不为 None 时,此选项指示是否保留 state_dict 键中的子模块前缀。例如,如果子模块是 module.pretrain,并且参数的完整 FQN 是 pretrain.layer1.weight。当此选项为 True 时,返回的 state_dict 中参数的键将为 pretrain.layer1.weight。如果此选项为 False,则键将为 layer1.weight。请注意,如果 keep_submodule_prefixes 为 False,则可能存在冲突的 FQN,因此 submodules 中应该只有一个子模块。

  • strict: 当 set_state_dict 调用 model.load_state_dict() 时的 strict 选项。

  • broadcast_from_rank0: 当此选项为 True 时,rank0 应该接收一个

    完整的 state_dict,并将 state_dict/ optim_state_dict 中的张量逐个广播到其他 rank。其他 rank 将接收张量,并根据模型和优化器中的本地分片进行分片。full_state_dict 在使用此选项时必须设置为 True。此选项目前仅支持 DTensor,不支持旧版的 ShardedTensor。

对于习惯于使用和共享 torch.save 格式模型的用户,我们提供了以下方法,这些方法提供了离线实用程序,用于在格式之间进行转换。

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source][source]

给定一个包含 DCP 检查点(checkpoint)的目录,此函数会将其转换为 Torch save 文件。

参数
  • dcp_checkpoint_dir (Union[str, PathLike]) – 包含 DCP 检查点(checkpoint)的目录。

  • torch_save_path (Union[str, PathLike]) – 用于存储转换后的 Torch save 文件的文件名。

警告

为了避免 OOM,建议仅在单个 rank 上运行此函数。

torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source][source]

给定 Torch save 文件的位置,将其转换为 DCP 检查点(checkpoint)。

参数
  • torch_save_path (Union[str, PathLike]) – Torch save 文件的文件名。

  • dcp_checkpoint_dir (Union[str, PathLike]) – 用于存储 DCP 检查点(checkpoint)的目录。

警告

为了避免 OOM,建议仅在单个 rank 上运行此函数。

以下类也可用于从 torch.save 格式在线加载和重新分片模型。

class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source][source]

用于读取 Torch Save 文件的 StorageReader。此读取器将在协调器 rank 上读取整个检查点(checkpoint),然后将每个张量广播和分片到所有 rank。

. 注意: 旨在与 DynamicMetaLoadPlanner 一起使用

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
prepare_global_plan(global_plan)[source][source]

StorageReader 方法的实现

返回类型

List[LoadPlan]

prepare_local_plan(plan)[source][source]

StorageReader 方法的实现

返回类型

LoadPlan

read_data(plan, planner)[source][source]

在协调器 rank 上读取 torch save 数据,然后进行广播,这会产生通信成本,但避免了在每个 rank 上加载整个检查点(checkpoint),有望防止 OOM 问题

返回类型

Future[None]

read_metadata()[source][source]

扩展默认 StorageReader 以支持构建元数据文件

返回类型

元数据

reset(checkpoint_id=None)[source][source]

StorageReader 方法的实现

set_up_storage_reader(metadata, is_coordinator)[source][source]

StorageReader 方法的实现

classmethod validate_checkpoint_id(checkpoint_id)[source][source]

StorageReader 方法的实现

返回类型

bool

class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]

DefaultLoadPlanner 的扩展,它基于传入的 state dict 创建一个新的 Metadata 对象,从而避免了从磁盘读取元数据的需要。这在读取没有元数据文件的格式(如 Torch Save 文件)时非常有用。

. 注意: 旨在与 BroadcastingTorchSaveReader 一起使用

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
set_up_planner(state_dict, metadata=None, is_coordinator=False)[source][source]

规划器的设置,通过从 state dict 创建 Metadata 对象来扩展默认行为

以下实验性接口旨在提高生产环境中的可观察性

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取您的问题解答

查看资源