分布式检查点 - torch.distributed.checkpoint¶
分布式检查点(DCP)支持从多个进程(rank)并行加载和保存模型。它处理加载时的重新分片(resharding),从而可以在一种集群拓扑中保存,并在另一种集群拓扑中加载。
DCP 与 torch.save 和 torch.load 在几个重要方面有所不同
每个检查点会产生多个文件,每个进程(rank)至少一个。
它以原地(in place)方式操作,这意味着模型应首先分配其数据,然后 DCP 会使用该存储空间。
加载和保存检查点的入口点如下
附加资源:¶
- class torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType(value)[source][source]¶
异步检查点类型枚举。
- torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, no_dist=False)[source][source]¶
以 SPMD 风格保存分布式模型。
此函数与
torch.save()
不同,因为它通过让每个进程(rank)只保存其本地分片来处理ShardedTensor
和DTensor
。对于每个
Stateful
对象(具有state_dict
和load_state_dict
),save 在序列化之前会调用state_dict
。警告
保存的 state_dict 不保证在不同 PyTorch 版本之间向后兼容。
警告
如果使用 process_group 参数,请确保只有属于该进程组的进程(rank)调用 save_state_dict,并且 state_dict 中的所有数据都属于该进程组。
注意
为 FSDP 的 ShardingStrategy.HYBRID_SHARD 保存检查点时,只有分片组(shard_group)中的一个应调用 save_state_dict,并且需要传入相应的进程组。
注意
- 如果没有可用的进程组,此函数假定目的是保存
state_dict 到本地进程。
- 参数
state_dict (Dict[str, Any]) – 要保存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储后端。它可以是文件夹或文件的路径,如果存储是键值存储,它也可以是键。(默认值:
None
)storage_writer (Optional[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果未指定此参数,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会抛出异常。(默认值:
None
)planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定此参数,将使用默认规划器。(默认值:
None
)process_group (Optional[ProcessGroup]) – 用于跨进程(rank)同步的 ProcessGroup。(默认值:
None
)no_dist (bool) – 如果为
True
,此函数将假定意图是不使用跨进程(rank)同步来加载检查点。(默认值:False
)
- 返回
保存的检查点的元数据对象。
- 返回类型
Metadata
示例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( ... "/checkpoint/1" ... ) >>> torch.distributed.checkpoint.save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> )
注意
save_state_dict 使用集合通信(collectives)来协调跨进程(rank)的写入。对于基于 NCCL 的进程组,对象内部的张量表示必须在通信发生之前移动到 GPU 设备上。在这种情况下,使用的设备由
torch.cuda.current_device()
给出,用户有责任确保此设置正确,以便每个进程(rank)都有一个独立的 GPU,通过调用torch.cuda.set_device()
。
- torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, async_checkpointer_type=AsyncCheckpointerType.THREAD)[source][source]¶
save
的异步版本。此代码首先将 state_dict 卸载到暂存存储(默认为 CPU 内存),然后在单独的线程中调用 save。警告
此功能是实验性的,可能会发生变化。
- 参数
state_dict (Dict[str, Any]) – 要保存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储后端。它可以是文件夹或文件的路径,如果存储是键值存储,它也可以是键。(默认值:
None
)storage_writer (Optional[StorageWriter]) – 用于执行“暂存”(stage)和“保存”(save)的 StorageWriter 实例。如果未指定此参数,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会抛出异常。(默认值:
None
)planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定此参数,将使用默认规划器。(默认值:
None
)process_group (Optional[ProcessGroup]) – 用于跨进程(rank)同步的 ProcessGroup。(默认值:
None
)
- 返回
一个 Future 对象,包含 save 生成的 Metadata 对象。
- 返回类型
示例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( ... "/checkpoint/1" ... ) >>> checkpoint_future = torch.distributed.checkpoint.async_save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> ) >>> >>> # ... do some work ... >>> >>> checkpoint_future.result()
- torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]¶
此方法已弃用。请改用 ‘save’。
- 返回类型
Metadata
- torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None, no_dist=False)[source][source]¶
以 SPMD 风格将检查点加载到分布式 state dict 中。
提供给此 API 的每个进程(rank)的
state_dict
必须包含相同的键。键不匹配可能导致程序挂起或错误。如果不确定,可以使用utils._assert_same_keys
API 进行检查(但这可能会产生通信开销)。每个进程(rank)都会尝试读取满足所需 state_dict 所需的最少量数据。加载
ShardedTensor
或DTensor
实例时,每个进程(rank)仅读取其本地分片的数据。对于每个
Stateful
对象(具有state_dict
和load_state_dict
),load 在尝试反序列化之前会首先调用state_dict
,反序列化完成后再调用load_state_dict
。对于每个非Stateful
对象,load 将反序列化该对象,然后在state_dict
中用反序列化后的对象替换它。警告
调用此函数 *之前*,
state_dict
中的所有张量必须已在其目标设备上分配。所有非张量数据都使用 torch.load() 加载,并在 state_dict 上原地修改。
警告
用户必须在根模块上调用 load_state_dict 以确保加载后处理和非张量数据正确传播。
- 参数
state_dict (Dict[str, Any]) – 要加载检查点到其中的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储后端。它可以是文件夹或文件的路径,如果存储是键值存储,它也可以是键。(默认值:
None
)storage_reader (Optional[StorageReader]) – 用于执行读取的 StorageReader 实例。如果未指定此参数,DCP 将根据 checkpoint_id 自动推断读取器。如果 checkpoint_id 也为 None,则会抛出异常。(默认值:
None
)planner (Optional[LoadPlanner]) – LoadPlanner 实例。如果未指定此参数,将使用默认规划器。(默认值:
None
)process_group (Optional[ProcessGroup]) – 用于跨进程(rank)同步的 ProcessGroup。(默认值:
None
)no_dist (bool) – 如果为
True
,此函数将假定意图是不使用跨进程(rank)同步来加载检查点。(默认值:False
)
- 返回
无。
- 返回类型
无
- 示例
>>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader( ... "/checkpoint/1" ... )
>>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> storage_reader=fs_storage_reader, >>> )
>>> # module.load_state_dict() function might have customized steps >>> # to flush the state_dict, must call it to >>> # ensure correct behavior. >>> my_model.load_state_dict(model_state_dict)
注意
load_state_dict 使用集合通信(collectives)来协调跨进程(rank)的读取。对于基于 NCCL 的进程组,对象内部的张量表示必须在通信发生之前移动到 GPU 设备上。在这种情况下,使用的设备由
torch.cuda.current_device()
给出,用户有责任确保此设置正确,以便每个进程(rank)都有一个独立的 GPU,通过调用torch.cuda.set_device()
。
- torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source][source]¶
此方法已弃用。请改用 ‘load’。
以下模块对于异步检查点(torch.distributed.checkpoint.async_save)中使用的暂存(staging)机制进行额外的自定义也很有用
- class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)[source][source]¶
此协议旨在为 dcp.async_save 提供自定义和可扩展性,允许用户在并行执行常规 dcp.save 路径之前自定义数据暂存(staged)的方式。预期的操作顺序(具体定义在 torch.distributed.state_dict_saver.async_save 中)如下:
- AsyncStager.stage(state_dict)
此调用使 AsyncStager 有机会对 state_dict 进行“暂存”(stage)。在此上下文中,暂存的预期和目的是创建一个“训练安全”(training-safe)的 state dict 表示,这意味着在暂存完成后对模块数据进行的任何更新都不应反映在此方法返回的 state dict 中。例如,在默认情况下,会在 CPU RAM 上创建整个 state dict 的副本并在此返回,从而允许用户在不冒正在序列化的数据被更改的风险的情况下继续训练。
- 在 stage 返回的 state_dict 上并行调用 dcp.save。此调用负责
序列化 state_dict 并将其写入存储。
- 如果 AsyncStager.should_synchronize_after_execute 为 True,此方法将立即在
序列化线程启动后、dcp.async_save 返回之前被调用。如果设置为 False,则假定用户已定义了一个自定义同步点,以进一步优化训练循环中的保存延迟(例如,通过将暂存与前向/后向传播重叠),用户有责任在适当的时间调用 AsyncStager.synchronize_staging。
- class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)[source][source]¶
AsyncStager 的一个实现,它将 state_dict 暂存到 CPU RAM 并阻塞直到复制完成。此实现还提供了一个选项,可以使用锁定内存(pinned memory)优化暂存延迟。
注意:在这种情况下,synchronize_staging 是一个空操作(no-op)。
除了上述入口点之外,如下所述的Stateful对象在保存/加载期间提供了额外的定制功能。.. automodule:: torch.distributed.checkpoint.stateful
- class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source][source]¶
用于可检查点和恢复的对象的 Stateful 协议。
这个例子展示了如何使用 Pytorch 分布式检查点保存 FSDP 模型。
以下类型定义了在检查点期间使用的 IO 接口
- class torch.distributed.checkpoint.StorageReader[source][source]¶
由
load_state_dict
使用的接口,用于从存储中读取。一个 StorageReader 实例在分布式检查点中同时充当协调者和跟随者。作为初始化的一部分,每个实例都会被告知其角色。
子类应预期由
load_state_dict
进行的以下调用序列(所有 rank) 如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。
(所有 rank) read_metadata()
(所有 rank) set_up_storage_reader()
(所有 rank) prepare_local_plan()
(协调者) prepare_global_plan()
(所有 rank) read_data()
- abstract prepare_global_plan(plans)[source][source]¶
执行存储加载的集中式规划。
此方法仅在协调者实例上被调用。
虽然此方法可以生成完全不同的计划,但推荐的方式是在 LoadPlan::storage_data 中存储特定于存储的数据。
- 参数
plans (list[torch.distributed.checkpoint.planner.LoadPlan]) – 一个
LoadPlan
实例列表,每个 rank 一个。- 返回
在存储全局规划之后转换后的
LoadPlan
列表- 返回类型
- abstract prepare_local_plan(plan)[source][source]¶
执行特定于存储的本地规划。
虽然此方法可以生成完全不同的计划,但推荐的方式是在 LoadPlan::storage_data 中存储特定于存储的数据。
- abstract read_data(plan, planner)[source][source]¶
使用
plan
读取所有项,并使用planner
来解析数据。子类应该调用
LoadPlanner::load_bytes
来反序列化一个 BytesIO 对象到正确的位置。子类应该调用
LoadPlanner::resolve_tensor
来获取应该加载数据进去的 tensors 的访问权限。适当调度任何所需的跨设备复制是 StorageLayer 的责任。
- 参数
plan (LoadPlan) – 要执行的本地计划。
planner (LoadPlanner) – 用于解析项的 planner 对象。
- 返回
一个在所有读取完成后完成的 future。
- 返回类型
Future[None]
- abstract reset(checkpoint_id=None)[source][source]¶
调用表明即将进行全新的检查点读取。如果用户为此次检查点读取设置了 checkpoint_id,则可能存在一个 checkpoint_id。checkpoint_id 的含义取决于存储。它可以是一个文件夹/文件的路径,或键值存储的键。
- 参数
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是一个文件夹或文件的路径。如果存储更像键值存储,它也可以是一个键。(默认值:
None
)
- class torch.distributed.checkpoint.StorageWriter[source][source]¶
由
save_state_dict
使用的接口,用于写入存储。一个 StorageWriter 实例在分布式检查点中同时充当协调者和跟随者。作为初始化的一部分,每个实例都会被告知其角色。
子类应预期以下调用序列。
(所有 rank) 如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。
(所有 rank) set_up_storage_writer()
(所有 rank) prepare_local_plan()
(协调者) prepare_global_plan()
(所有 rank) write_data()
(协调者) finish()
- abstract finish(metadata, results)[source][source]¶
写入元数据并将当前检查点标记为成功。
用于序列化 metadata 的实际格式/模式是一个实现细节。唯一的要求是它可以恢复到相同的对象图。
- abstract prepare_global_plan(plans)[source][source]¶
执行存储的集中式规划。
此方法仅在协调者实例上被调用。
虽然此方法可以生成完全不同的计划,但推荐的方式是在 SavePlan::storage_data 中存储特定于存储的数据。
- 参数
plans (list[torch.distributed.checkpoint.planner.SavePlan]) – 一个
SavePlan
实例列表,每个 rank 一个。- 返回
在存储全局规划之后转换后的
SavePlan
列表- 返回类型
- abstract prepare_local_plan(plan)[source][source]¶
执行特定于存储的本地规划。
虽然此方法可以生成完全不同的计划,但推荐的方式是在 SavePlan::storage_data 中存储特定于存储的数据。
- abstract reset(checkpoint_id=None)[source][source]¶
调用表明即将进行全新的检查点写入。如果用户为此次检查点写入设置了 checkpoint_id,则可能存在一个 checkpoint_id。checkpoint_id 的含义取决于存储。它可以是一个文件夹/文件的路径,或键值存储的键。
- 参数
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储后端。它可以是文件夹或文件的路径,如果存储是键值存储,它也可以是键。(默认值:
None
)
- abstract set_up_storage_writer(is_coordinator)[source][source]¶
初始化此实例。
- 参数
is_coordinator (bool) – 此实例是否负责协调检查点。
- storage_meta()[source][source]¶
返回特定于存储的元数据。这用于在检查点中存储额外信息,这对于提供请求级别的可观测性很有用。在保存调用期间,StorageMeta 会传递给
SavePlanner
。默认返回 None。TODO: 提供一个例子
- 返回类型
Optional[StorageMeta]
- abstract classmethod validate_checkpoint_id(checkpoint_id)[source][source]¶
检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。
- 返回类型
- abstract write_data(plan, planner)[source][source]¶
使用
plan
写入所有项,并使用planner
来解析数据。子类应该调用
SavePlanner::resolve_data
对 plan 中的每个项,来获取要写入的底层对象的访问权限。子类应该延迟调用 resolve_data,因为它可能会分配内存。对于 tensors,做出以下假设
- 他们可能在任何设备上,包括与
WriteItem::tensor_data
上的不匹配- 他们可能是 views 或非连续的。只需保存 projection。
- 参数
plan (SavePlan) – 要执行的保存计划。
planner (SavePlanner) – 用于将项解析为数据的 planner 对象。
- 返回
一个完成后返回 WriteResult 列表的 future。
- 返回类型
Future[list[torch.distributed.checkpoint.storage.WriteResult]]
以下类型定义了在检查点期间使用的 planner 接口
- class torch.distributed.checkpoint.LoadPlanner[source][source]¶
定义了由 load_state_dict 用于规划加载过程的协议的抽象类。
LoadPlanner 是可用于定制整个加载过程的 stateful 对象。
LoadPlanner 作为 state_dict 的访问代理,因此对其进行的任何转换将对整个过程可见。
一个 planner 子类可以在 load_state_dict 期间预期以下调用序列
- set_up_planner - 在所有 rank 上调用。
标志着加载检查点的开始。
- create_local_plan - 在所有 rank 上调用。
处理 state_dict 并生成一个 LoadPlan,该 LoadPlan 将被发送用于全局规划。
- create_global_plan - 仅在协调者 rank 上调用。
接收来自所有 rank 的 LoadPlan 并做出任何全局决策。
- load_bytes - 在每个 rank 上多次调用
这对于 state_dict 中的每个非 tensor 值调用一次。
- resolve_tensor 和 commit_tensor - 在每个 rank 上多次调用
它们对于 state_dict 中的每个 Tensor 值成对调用。
建议用户继承 DefaultLoadPlanner 而不是直接使用此接口,因为大多数更改可以通过修改单个方法来表达。
有两种常见的扩展模式
- 重写 state_dict。这是扩展加载过程的最简单方法,因为它不需要理解 LoadPlan 工作原理的复杂细节。我们需要保留对原始 state_dict 的引用,因为加载是原地发生的,所以我们需要能够原地执行它。
>>> class RenamePlanner(DefaultLoadPlanner): >>> def set_up_planner( >>> self, >>> state_dict: STATE_DICT_TYPE, >>> metadata: Metadata, >>> is_coordinator: bool, >>> ) -> None: >>> self.original_state_dict = state_dict >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} >>> >>> if self.flatten_sharded_tensors: >>> state_dict = _flatten_sharded_tensors(state_dict) >>> >>> if self.flatten_state_dict: >>> state_dict, self.mappings = flatten_state_dict(state_dict) >>> >>> self.state_dict = state_dict >>> self.metadata = metadata >>> self.is_coordinator = is_coordinator >>> >>> def load_bytes(self, read_item, value): >>> # Remove the "foo_" prefix >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)
- 修改 resolve_tensor 和 commit_tensor 来处理加载时转换。
>>> class MetaModelMaterialize(DefaultSavePlanner): >>> def resolve_tensor(self, read_item): >>> tensor = super().resolve_tensor(read_item) >>> return torch.empty_like(tensor, device="cpu") >>> >>> def commit_tensor(self, read_item, tensor): >>> self.state_dict[read_item.dest_index.fqn] = tensor
- abstract commit_tensor(read_item, tensor)[source][source]¶
在 StorageReader 完成将数据加载到
tensor
中后调用。提供的 tensor 与由对
resolve_tensor
的调用返回的 tensor 是同一个。仅当此 LoadPlanner 需要在将其复制回 state_dict 中的 tensor 之前后处理tensor
时才需要此方法。tensor 的内容将遵循其设备同步模型。
- abstract create_global_plan(global_plan)[source][source]¶
计算全局加载计划并返回每个 rank 的计划。
. 注意:这仅在协调者 rank 上调用
- abstract create_local_plan()[source][source]¶
基于 state_dict 和由 set_up_planner 提供的 metadata 创建一个 LoadPlan。
. 注意:这在每个 rank 上调用。
- 返回类型
- abstract load_bytes(read_item, value)[source][source]¶
加载由
read_item
和value
描述的项。此方法预期会原地修改底层 state_dict。
value
的内容由用于生成正在加载的检查点的 SavePlanner 定义。
- resolve_bytes(read_item)[source][source]¶
返回由 StorageReader 使用来加载 read_item 的 BytesIO。
该 BytesIO 应该与底层 state_dict 中的一个 alias,因为 StorageReader 将替换其内容。
- 返回类型
BytesIO
- class torch.distributed.checkpoint.LoadPlan(items: list[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source][source]¶
- class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source][source]¶
- class torch.distributed.checkpoint.SavePlanner[source][source]¶
定义了
save_state_dict
用于规划保存过程的协议的抽象类。SavePlanner 是有状态的对象,可用于自定义整个保存过程。
SavePlanner 作为
state_dict
的访问代理,因此对其进行的任何转换将对整个过程可见。planner 子类在
save_state_dict
期间可预期以下调用序列:- set_up_planner - 在所有 rank 上调用。
表示检查点保存的开始。
- create_local_plan - 在所有 rank 上调用。
处理
state_dict
并生成一个 SavePlan,该 SavePlan 将用于全局规划。
- create_global_plan - 仅在协调者 rank 上调用。
接收来自所有 rank 的 SavePlan 并做出任何全局决策。
finish_plan
- 在所有 rank 上调用。这让每个 rank 有机会根据全局规划决策进行调整。
resolve_data
- 在每个 rank 上多次调用在 state_dict 中查找一个值,供存储层写入。
建议用户继承
DefaultSavePlanner
而非直接继承此接口,因为大多数更改可以通过修改单个方法来实现。有 3 种常见的扩展模式:
重写
state_dict
。这是扩展保存过程最简单的方式,因为它不需要理解 SavePlan 工作原理的复杂性。>>> class RenamePlanner(DefaultSavePlanner): >>> def set_up_planner( >>> self, >>> state_dict: STATE_DICT_TYPE, >>> storage_meta: Optional[StorageMeta], >>> is_coordinator: bool, >>> ) -> None: >>> # prefix all keys with `foo_`` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)
同时修改局部计划和查找。这在需要精细控制数据如何持久化时非常有用。
>>> class FP16Planner(DefaultSavePlanner): >>> def create_local_plan(self): >>> plan = super().create_local_plan() >>> for p in plan: >>> if p.tensor_data is not None: >>> p.tensor_data.properties.dtype = torch.float16 >>> return plan >>> >>> def resolve_data(self, write_item): >>> item = super().resolve_data(write_item) >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
利用全局规划步骤来做出每个 rank 无法独立做出的中心决策。
>>> from itertools import zip_longest >>> from dataclasses import replace >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 >>> # This sample doesn't handle ShardedTensors >>> def create_global_plan(self, all_plans): >>> iters = [iter(all_plans[0].items)] * len(all_plans) >>> items_per_rank = [ >>> [item for item in items if item is not None] >>> for items in zip(*zip_longest(*iters), strict=True) >>> ] >>> all_plans = [ >>> replace(plan, items=items) >>> for plan, items in zip(all_plans, items_per_rank, strict=True) >>> ] >>> return super().create_global_plan(all_plans)
最后,有些 planner 需要在检查点中保存额外的元数据,这通过让每个 rank 在局部计划中贡献其数据项并由全局 planner 聚合它们来实现。
>>> class SaveExtraDataPlanner(DefaultSavePlanner): >>> def create_local_plan(self) -> SavePlan: >>> plan = super().create_local_plan() >>> return replace(plan, planner_data="per-rank-data") >>> >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: >>> global_plan, metadata = super().create_global_plan(all_plans) >>> merged_data = [p.planner_data for p in global_plan] >>> metadata = replace(metadata, planner_data=merged_data) >>> return global_plan, metadata
- abstract create_global_plan(all_plans)[source][source]¶
计算全局检查点计划,并返回每个 rank 的局部计划。
这仅在 coordinator rank(协调者 rank)上调用。
- 返回类型
tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata]
- abstract create_local_plan()[source][source]¶
计算当前 rank 的保存计划。
这将被聚合并传递给
create_global_plan
。planner 特有的数据可以通过SavePlan::planner_data
传递。这在所有 rank 上调用。
- 返回类型
- abstract finish_plan(new_plan)[source][source]¶
合并由 create_local_plan 创建的计划和 create_global_plan 的结果。
这在所有 rank 上调用。
- 返回类型
- abstract resolve_data(write_item)[source][source]¶
转换并准备来自
state_dict
的write_item
用于存储,确保幂等性和线程安全。在
state_dict
中查找与write_item
关联的对象,并在存储层使用它之前应用任何转换(例如序列化)。在每个 rank 上多次调用,至少在最终 SavePlan 中的每个 WriteItem 调用一次。
此方法应为幂等且线程安全。StorageWriter 实现可以根据需要自由调用它。
任何需要分配内存的转换都应在此方法调用时进行延迟执行,以减少检查点操作所需的峰值内存。
返回张量时,它们可以在任何设备上,可以是任何格式,也可以是视图。由存储层负责确定如何保存它们。
- class torch.distributed.checkpoint.SavePlan(items: list[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None, usable: bool =True)[source][source]¶
- class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[source][source]¶
一个数据类,包含关于需要写入存储的内容的信息。
我们提供一个基于文件系统的存储层。
- class torch.distributed.checkpoint.FileSystemReader(path, _extension_registry=None)[source][source]¶
- class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True, _extensions=None)[source][source]¶
使用文件 IO 的 StorageWriter 基本实现。
此实现做出以下假设和简化:
检查点路径是一个空目录或不存在的目录。
文件创建是原子性的。
检查点由每个写入请求一个文件,再加上一个包含序列化元数据的 .metadata 文件组成。
我们提供了 LoadPlanner 和 SavePlanner 的默认实现,它们可以处理所有 torch.distributed 构造,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。
- class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False, enable_plan_caching=False)[source][source]¶
- class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]¶
在 LoadPlanner 的基础上添加了多个功能的
DefaultLoadPlanner
。特别地,它添加了以下功能:
flatten_state_dict
: 处理带有嵌套字典的state_dict
。flatten_sharded_tensors
: 用于 2D 并行模式下的 FSDP。allow_partial_load
: 如果为 False,则当state_dict
中存在而检查点中不存在某个键时,将引发运行时错误。
由于历史设计决策,FSDP 和 DDP 的 state dictionary 可能具有不同的键或完全限定名(例如 layer1.weight),即使原始未并行化的模型是相同的。此外,FSDP 提供各种类型的模型 state dictionary,例如完整和分片 state dictionary。另外,optimizer state dictionary 使用参数 ID 而非完全限定名来标识参数,在使用并行化时可能导致问题(例如流水线并行)。
为了解决这些挑战,我们提供了一系列 API,供用户轻松管理 state_dict。get_model_state_dict() 返回一个模型 state dictionary,其键与未并行化模型 state dictionary 返回的键一致。类似地,get_optimizer_state_dict() 提供 optimizer state dictionary,其键在所有应用的并行化中保持一致。为了实现这种一致性,get_optimizer_state_dict() 将参数 ID 转换为与未并行化模型 state dictionary 中找到的完全相同的完全限定名。
请注意,这些 API 返回的结果可以直接用于 torch.distributed.checkpoint.save() 和 torch.distributed.checkpoint.load() 方法,无需任何额外的转换。
提供了 set_model_state_dict() 和 set_optimizer_state_dict() 方法,用于加载由其各自的 getter API 生成的模型和 optimizer 的 state_dict。
请注意,set_optimizer_state_dict() 只能在 optimizer 上调用 backward() 之前或调用 step() 之后进行调用。
请注意,此功能是实验性的,并且 API 签名将来可能会更改。
- torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]¶
返回模型 state_dict 和 optimizers state_dict。
get_state_dict
可以处理由 PyTorch FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 并行化的任何模块,以及这些并行化模式的任何组合。get_state_dict
的主要功能是:1.) 返回一个可以根据不同数量的训练器和/或不同并行化模式进行重新分片 (resharded) 的模型和 optimizer state_dict。2.) 隐藏特定于并行化的 state_dict API,用户无需调用这些 API。3.) 对结果 state_dict 进行健全性检查 (sanity checking)。结果 state dictionary 的键是规范的 FQN (完全限定名 Fully Qualified Names)。规范 FQN 指的是基于参数在 nn.Module 层次结构中的位置的 FQN。更具体地说,当模块未被任何并行化模式分布式时,参数的规范 FQN 是由
module.named_parameters()
或module.named_buffers()
返回的 FQN。由于 optimizer 在内部使用参数 ID 来表示参数,调用此 API 时将会进行从参数 ID 到规范 FQN 的转换。get_state_dict
也可以处理一个未并行化的模块。在这种情况下,get_state_dict
只执行一个功能——将 optimizer 参数 ID 转换为规范 FQN。示例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model)) >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_model = DDP(copy.deepcopy(model)) >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( ... fsdp_model, fsdp_optim ... )
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # the asserts will fail. >>> assert ddp_state_dict == fsdp_state_dict >>> assert ddp_optim_state == fsdp_optim_state_dict
- 参数
model (nn.Module) – 表示模型的 nn.Module 对象。
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化
model
的 optimizer 对象。submodules (已废弃) – Optional[set[nn.Module]]:只返回属于这些子模块的模型参数。
options (StateDictOptions) – 用于控制如何返回模型 state_dict 和 optimizer state_dict 的选项。详情请参阅 StateDictOptions。
- 返回
包含模型 state_dict 和 optimizer state_dict 的
Tuple
。- 返回类型
- torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source][source]¶
返回
model
的模型状态字典(state_dict)。详细用法请参见
get_state_dict
。- 参数
model (nn.Module) – 表示模型的 nn.Module 对象。
submodules (已废弃) – Optional[set[nn.Module]]:只返回属于这些子模块的模型参数。
options (StateDictOptions) – 用于控制如何返回模型 state_dict 和 optimizer state_dict 的选项。详情请参阅 StateDictOptions。
- 返回
model
的状态字典。- 返回类型
- torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source][source]¶
返回优化器的组合状态字典(state_dict)。
详细用法请参见
get_state_dict
。- 参数
model (nn.Module) – 表示模型的 nn.Module 对象。
optimizers (Union[None, Optimizer, Iterable[Optimizer]]) – 用于优化
model
的 optimizer 对象。submodules (已废弃) – Optional[set[nn.Module]]:只返回属于这些子模块的模型参数。
options (StateDictOptions) – 用于控制如何返回模型 state_dict 和 optimizer state_dict 的选项。详情请参阅 StateDictOptions。
- 返回
optimizers
的状态字典。- 返回类型
OptimizerStateType
- torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source][source]¶
加载模型状态字典和优化器状态字典。
get_state_dict
的对应函数,用于将状态字典设置到模型和优化器中。给定的model_state_dict
和optim_state_dict
不必是get_state_dict
返回的,但必须满足以下要求:1) 所有 FQNs 都是get_state_dict
中定义的规范 FQNs,2) 如果张量是分片的,它必须是 ShardedTensor 或 DTensor,3) 优化器状态字典不能包含参数 ID;键应该是规范的 FQNs。- 警告:
set_state_dict
只能在backward()
调用之前或在优化器上调用step()
之后调用。否则,优化器状态将无法正确初始化。
- 参数
model (nn.Module) – 表示模型的 nn.Module 对象。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化
model
的优化器。model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要加载的模型状态字典。如果
model_state_dict
的键是 nn.Module,则该键是model
的一个子模块,值应该是该子模块的状态字典。加载状态字典时,子模块的前缀将添加到状态字典中。optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器状态字典。
options (StateDictOptions) – 控制如何加载模型状态字典和优化器状态字典的选项。详情请参见 StateDictOptions。
- 返回
missing_keys 是一个包含模型状态字典中缺失键的字符串列表。
unexpected_keys 是一个包含模型状态字典中意外键的字符串列表。
- 返回类型
NamedTuple
包含missing_keys
和unexpected_keys
字段
- 警告:
- torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source][source]¶
加载模型状态字典。
get_model_state_dict
的对应函数,用于将状态字典设置到模型中。详细用法请参见set_state_dict
。- 参数
model (nn.Module) – 表示模型的 nn.Module 对象。
model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 要加载的模型状态字典。如果
model_state_dict
的键是 nn.Module,则该键是model
的一个子模块,值应该是该子模块的状态字典。加载状态字典时,子模块的前缀将添加到状态字典中。options (StateDictOptions) – 控制如何加载模型状态字典和优化器状态字典的选项。详情请参见 StateDictOptions。
- 返回
missing_keys 是一个包含缺失键的字符串列表
unexpected_keys 是一个包含意外键的字符串列表
- 返回类型
NamedTuple
包含missing_keys
和unexpected_keys
字段
- torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)[source][source]¶
加载优化器状态字典。
get_optimizer_state_dict
的对应函数,用于将状态字典设置到优化器中。详细用法请参见set_state_dict
。- 警告:
set_optimizer_state_dict
只能在backward()
调用之前或之后调用 step()
。否则,优化器状态将无法正确初始化。
- 参数
model (nn.Module) – 表示模型的 nn.Module 对象。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化
model
的优化器。optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器状态字典。
options (StateDictOptions) – 控制如何加载模型状态字典和优化器状态字典的选项。详情请参见 StateDictOptions。
- 返回
无
- 返回类型
无
- 警告:
- class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False, dsd_fqn_modifiers='_fqn_modifiers')[source][source]¶
此数据类指定 get_state_dict/set_state_dict 的工作方式。
full_state_dict
:如果设置为 True,返回的状态字典中的所有张量都将被收集。返回的状态字典中将不包含 ShardedTensor 和 DTensor。cpu_offload
:将所有张量卸载到 CPU。为防止 CPU OOM,如果full_state_dict
也为 True,则只有 rank0 会获取状态字典,所有其他 ranks 将获取空状态字典。ignore_frozen_params
:如果值为 True,返回的状态字典将不包含任何冻结参数——requires_grad
为 False 的参数。默认值为 False。keep_submodule_prefixes
(已弃用):当submodules
不为 None 时,此选项指示是否从状态字典的键中保留子模块前缀。例如,如果子模块是module.pretrain
,并且参数的完整 FQN 是pretrain.layer1.weight
。当此选项为 True 时,返回的状态字典中该参数的键将是pretrain.layer1.weight
。如果选项为 False,键将是layer1.weight
。请注意,如果keep_submodule_prefixes
为 False,则可能存在冲突的 FQNs,因此submodules
中应该只有一个子模块。strict
:set_state_dict
调用 model.load_state_dict() 时的strict
选项。broadcast_from_rank0
:当选项为 True 时,rank0 将接收一个完整状态字典,并将状态字典/optim_state_dict 中的张量逐个广播到其他 ranks。其他 ranks 将接收张量并根据模型和优化器中的本地分片进行分片。
full_state_dict
在使用此选项时必须设置为 True。此选项目前仅支持 DTensor,不支持旧的 ShardedTensor。
对于习惯于使用和共享 torch.save 格式模型的用户,提供了以下方法,它们提供了用于在格式之间转换的离线工具。
- torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source][source]¶
给定一个包含 DCP 检查点(checkpoint)的目录,此函数将将其转换为 Torch 保存文件。
- 参数
警告
为避免 OOM,建议仅在单个 rank 上运行此函数。
- torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source][source]¶
给定 Torch 保存文件的位置,将其转换为 DCP 检查点。
- 参数
警告
为避免 OOM,建议仅在单个 rank 上运行此函数。
以下类也可以用于在线加载和重新分片 torch.save 格式的模型。
- class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source][source]¶
用于读取 Torch Save 文件的 StorageReader。此读取器将在协调器(coordinator)rank 上读取整个检查点,然后将每个张量广播并分片到所有 ranks。
. N.B. 旨在与 DynamicMetaLoadPlanner 一起使用。
警告
当前实现仅支持加载张量。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )
- class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)[source][source]¶
DefaultLoadPlanner 的扩展,它基于传入的状态字典创建新的 Metadata 对象,避免了从磁盘读取元数据。这在读取没有元数据文件的格式(如 Torch Save 文件)时非常有用。
. N.B. 旨在与 BroadcastingTorchSaveReader 一起使用。
警告
当前实现仅支持加载张量。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )
以下实验性接口提供了在生产环境中改进可观测性的功能