分布式检查点 - torch.distributed.checkpoint¶
分布式检查点 (DCP) 支持从多个排名中并行加载和保存模型。它处理加载时重新分片,这使得能够在一个集群拓扑中保存并在另一个集群拓扑中加载。
DCP 与 torch.save 和 torch.load 在几个重要方面有所不同
它为每个检查点生成多个文件,每个排名至少一个文件。
它在原地操作,这意味着模型应该首先分配其数据,DCP 使用该存储而不是重新分配。
加载和保存检查点的入口点如下
- torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source]¶
以 SPMD 样式保存分布式模型。
此函数与
torch.save()
不同,因为它处理ShardedTensor
和DTensor
,每个进程只保存其本地分片。对于每个
Stateful
对象(同时具有state_dict
和load_state_dict
),save 会在序列化之前调用state_dict
。警告
对于保存的 state_dict,跨 PyTorch 版本的向后兼容性没有保证。
警告
如果使用 process_group 参数,请确保只有其进程调用 save_state_dict,并且 state_dict 中的所有数据都属于它。
注意
为 FSDP 的 ShardingStrategy.HYBRID_SHARD 保存检查点时,只有一个分片组应该调用 save_state_dict,并且需要传入相应的进程组。
注意
- 如果没有可用的进程组,此函数假设意图是在本地进程中保存
state_dict。
- 参数
state_dict (Dict[str, Any]) – 要保存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认值:
None
)storage_writer (Optional[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:
None
)planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定,将使用默认规划器。(默认值:
None
)process_group (Optional[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。(默认值:
None
)
- 返回值
已保存检查点的元数据对象。
- 返回类型
Metadata
示例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> torch.distributed.checkpoint.save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> )
注意
save_state_dict 使用集体操作来协调跨等级的写入。对于基于 NCCL 的进程组,对象内部的张量表示必须在通信发生之前移动到 GPU 设备。在这种情况下,使用的设备由
torch.cuda.current_device()
给出,用户有责任确保通过torch.cuda.set_device()
设置每个等级都有一个单独的 GPU。
- torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[source]¶
save_state_dict
的异步版本。此代码首先在 CPU 上取消分段 state_dict,然后在单独的线程中调用 save。警告
此功能为实验性功能,可能随时更改。
- 参数
state_dict (Dict[str, Any]) – 要保存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认值:
None
)storage_writer (Optional[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:
None
)planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定,将使用默认规划器。(默认值:
None
)process_group (Optional[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。(默认值:
None
)
- 返回值
将来保存 save 方法返回的元数据对象。
- 返回类型
示例
>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") >>> checkpoint_future = torch.distributed.checkpoint.async_save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, >>> ) >>> >>> # ... do some work ... >>> >>> checkpoint_future.result()
- torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]¶
此方法已弃用。请切换到“save”。
- 返回类型
Metadata
- torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None)[source]¶
以 SPMD 样式加载分布式
state_dict
。每个进程将尝试读取满足请求的 state_dict 所需的最少数据。在加载
ShardedTensor
或DTensor
实例时,每个进程只读取其本地分片的相关数据。对于每个
Stateful
对象(同时具有state_dict
和load_state_dict
),load 将首先调用state_dict
,然后尝试反序列化,最后在反序列化完成后调用load_state_dict
。警告
在调用此函数之前,必须在
state_dict
中分配所有张量到其目标设备。所有非张量数据使用 torch.load() 加载,并在 state_dict 上就地修改。
警告
用户必须在根模块上调用 load_state_dict 以确保加载后处理和非张量数据正确传播。
- 参数
state_dict (Dict[str, Any]) – 要保存的 state_dict。
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认值:
None
)storage_reader (可选[StorageReader]) – 用于执行读取的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断读取器。如果 checkpoint_id 也为 None,则会引发异常。 (默认值:
None
)planner (可选[LoadPlanner]) – LoadPlanner 实例。如果未指定,将使用默认规划器。 (默认值:
None
)process_group (Optional[ProcessGroup]) – 用于跨等级同步的 ProcessGroup。(默认值:
None
)
- 返回值
无。
- 返回类型
无
- 示例
>>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> storage_reader=fs_storage_reader, >>> )
>>> # module.load_state_dict() function might have customized steps >>> # to flush the state_dict, must call it to >>> # ensure correct behavior. >>> my_model.load_state_dict(model_state_dict)
注意
load_state_dict 使用集体通信来协调跨秩的读取。对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生之前移动到 GPU 设备。在这种情况下,使用的设备由
torch.cuda.current_device()
给出,用户有责任确保通过torch.cuda.set_device()
设置此设备,以便每个秩都有一个单独的 GPU。
- torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[source]¶
此方法已弃用。请切换到“load”。
除了上述入口点之外,Stateful 对象(如下所述)在保存/加载期间提供了额外的自定义功能 .. automodule:: torch.distributed.checkpoint.stateful
- class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[source]¶
用于可以检查点和恢复的对象的 Stateful 协议。
此 示例 展示了如何使用 Pytorch Distributed Checkpoint 保存 FSDP 模型。
以下类型定义了检查点期间使用的 IO 接口
- class torch.distributed.checkpoint.StorageReader[source]¶
load_state_dict
用于从存储读取的接口。一个 StorageReader 实例在分布式检查点中充当协调器和跟随者。作为初始化的一部分,每个实例都会被告知其角色。
load_state_dict
预计子类将按以下顺序调用(所有等级) 如果用户传递有效的 checkpoint_id,则设置 checkpoint_id。
(所有等级) read_metadata()
(所有进程) set_up_storage_reader()
(所有进程) prepare_local_plan()
(协调器) prepare_global_plan()
(所有进程) read_data()
- abstract prepare_global_plan(plans)[source]¶
执行存储加载的集中式规划。
此方法仅在协调器实例上调用。
虽然此方法可以生成完全不同的计划,但首选方法是在 LoadPlan::storage_data 中存储特定于存储的数据。
- abstract prepare_local_plan(plan)[source]¶
执行特定于存储的本地规划。
虽然此方法可以生成完全不同的计划,但推荐的方法是在 LoadPlan::storage_data 中存储特定于存储的数据。
- abstract read_data(plan, planner)[source]¶
使用
planner
从plan
中读取所有项以解析数据。子类应该调用
LoadPlanner::load_bytes
将 BytesIO 对象反序列化到正确的位置。子类应该调用
LoadPlanner::resolve_tensor
来访问应该加载数据的张量。StorageLayer 负责正确调度任何所需的跨设备复制操作。
- 参数
plan (LoadPlan) – 要在本地执行的计划
planner (LoadPlanner) – 用于解析项目的规划器对象。
- 返回值
所有读取完成后完成的 Future。
- 返回类型
Future[None]
- abstract reset(checkpoint_id=None)[source]¶
调用表示将要进行全新的检查点读取。如果用户为此次检查点读取设置了 checkpoint_id,则可能存在 checkpoint_id。checkpoint_id 的含义取决于存储。它可以是文件夹/文件的路径,也可以是键值存储的键。
- 参数
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。如果存储更像是键值存储,它也可以是键。(默认值:
None
)
- class torch.distributed.checkpoint.StorageWriter[source]¶
由
save_state_dict
用于写入存储的接口。一个 StorageWriter 实例在分布式检查点中充当协调器和跟随者。作为初始化的一部分,每个实例都会被告知其角色。
子类应该预期以下调用顺序。
(所有等级) 如果用户传递有效的 checkpoint_id,则设置 checkpoint_id。
(所有等级) set_up_storage_writer()
(所有进程) prepare_local_plan()
(协调器) prepare_global_plan()
(所有等级) write_data()
(协调器) finish()
- abstract finish(metadata, results)[source]¶
写入元数据并将当前检查点标记为成功。
用于序列化元数据的实际格式/模式是一个实现细节。唯一的要求是它可以恢复到相同的对象图。
- 抽象 prepare_global_plan(plans)[source]¶
执行存储的集中式规划。
此方法仅在协调器实例上调用。
虽然此方法可以生成完全不同的计划,但首选方法是在 SavePlan::storage_data 中存储特定于存储的数据。
- 抽象 prepare_local_plan(plan)[source]¶
执行特定于存储的本地规划。
虽然此方法可以生成完全不同的计划,但推荐的方法是在 SavePlan::storage_data 中存储特定于存储的数据。
- abstract reset(checkpoint_id=None)[source]¶
指示将要进行全新的检查点写入的调用。如果用户为此检查点写入设置了 checkpoint_id,则可能存在 checkpoint_id。checkpoint_id 的含义取决于存储。它可以是文件夹/文件的路径或键值存储的键。
- 参数
checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是键。(默认值:
None
)
- abstract set_up_storage_writer(is_coordinator)[source]¶
初始化此实例。
- 参数
is_coordinator (bool) – 该实例是否负责协调检查点。
- abstract classmethod validate_checkpoint_id(checkpoint_id)[source]¶
检查给定的 checkpoint_id 是否受存储支持。这允许我们启用自动存储选择。
- 返回类型
- abstract write_data(plan, planner)[source]¶
使用
planner
从plan
中写入所有项目以解析数据。子类应该对计划中的每个项目调用
SavePlanner::resolve_data
以访问要写入的底层对象。子类应该延迟调用 resolve_data,因为它会分配内存。在张量的情况下,做出以下假设
它们可能位于任何设备上,包括与
WriteItem::tensor_data
上的设备不匹配的设备它们可能是视图或非连续的。只需要保存投影。
- 参数
plan (SavePlan) – 要执行的保存计划。
planner (SavePlanner) – 用于解析项目到数据的规划器对象。
- 返回值
一个完成为 WriteResult 列表的未来。
- 返回类型
以下类型定义了检查点期间使用的规划器接口
- class torch.distributed.checkpoint.LoadPlanner[source]¶
定义 load_state_dict 用于规划加载过程的协议的抽象类。
LoadPlanner 是有状态的对象,可用于自定义整个加载过程。
LoadPlanner 充当 state_dict 的访问代理,因此对它的任何转换都将对整个过程可见。
规划器子类可以预期在 load_state_dict 期间以下列顺序调用
- set_up_planner - 在所有等级上调用。
表示加载检查点的开始。
- create_local_plan - 在所有等级上调用。
处理 state_dict 并生成一个 LoadPlan,该计划将被发送用于全局规划。
- create_global_plan - 仅在协调器等级上调用。
获取所有等级的 LoadPlan 并做出任何全局决策。
- load_bytes - 在每个等级上多次调用
这在 state_dict 中的每个非张量值调用一次。
- resolve_tensor 和 commit_tensor - 在每个等级上多次调用
它们在 state_dict 中的每个张量值成对调用。
建议用户扩展 DefaultLoadPlanner 而不是直接扩展此接口,因为大多数更改可以通过单个方法的更改来表达。
扩展通常有两种模式
重写 state_dict。这是扩展加载过程的最简单方法,因为它不需要理解 LoadPlan 如何工作的复杂性。我们需要保留对原始 state_dict 的引用,因为加载是在原地发生的,因此我们需要能够在原地执行它
>>> class RenamePlanner(DefaultLoadPlanner): >>> def set_up_planner(self, state_dict, metadata, is_coordinator): >>> self.original_state_dict = state_dict >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} >>> >>> if self.flatten_sharded_tensors: >>> state_dict = _flatten_sharded_tensors(state_dict) >>> >>> if self.flatten_state_dict: >>> state_dict, self.mappings = flatten_state_dict(state_dict) >>> >>> self.state_dict = state_dict >>> self.metadata = metadata >>> self.is_coordinator = is_coordinator >>> >>> def load_bytes(self, read_item, value): >>> # Remove the "foo_" prefix >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
修改 resolve_tensor 和 commit_tensor 以处理加载时转换。
>>> class MetaModelMaterialize(DefaultSavePlanner): >>> def resolve_tensor(self, read_item): >>> tensor = super().resolve_tensor(read_item) >>> return torch.empty_like(tensor, device="cpu") >>> >>> def commit_tensor(self, read_item, tensor): >>> self.state_dict[read_item.dest_index.fqn] = tensor
- abstract commit_tensor(read_item, tensor)[source]¶
在 StorageReader 完成将数据加载到
tensor
中后调用一次。提供的张量与
resolve_tensor
调用返回的张量相同。仅当此 LoadPlanner 需要在将tensor
复制回 state_dict 中的张量之前对其进行后处理时,才需要此方法。张量的内容将遵循其设备同步模型。
- abstract create_local_plan()[source]¶
根据 set_up_planner 提供的 state_dict 和元数据创建 LoadPlan。
. 注意:这在每个等级上调用。
- 返回类型
- abstract load_bytes(read_item, value)[source]¶
加载由
read_item``和 ``value
描述的项目。此方法预计会就地修改底层 state_dict。
value
的内容由用于生成正在加载的检查点的 SavePlanner 定义。
- class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[source]¶
- class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[source]¶
- class torch.distributed.checkpoint.SavePlanner[source]¶
定义用于 save_state_dict 的协议的抽象类,用于规划保存过程。
SavePlanner 是有状态的对象,可用于自定义整个保存过程。
SavePlanner 充当 state_dict 的访问代理,因此对其进行的任何转换都将对整个过程可见。
规划器子类可以预期在 save_state_dict 期间执行以下调用序列
- set_up_planner - 在所有等级上调用。
表示检查点保存的开始。
- create_local_plan - 在所有等级上调用。
处理 state_dict 并生成一个 SavePlan,该计划将被发送以进行全局规划。
- create_global_plan - 仅在协调器等级上调用。
获取所有等级的 SavePlan 并做出任何全局决策。
- finish_plan - 在所有等级上调用。
这使每个等级有机会调整到全局规划决策。
- resolve_data - 在每个等级上多次调用
在 state_dict 上查找存储层要写入的值。
建议用户扩展 DefaultSavePlanner 而不是直接扩展此接口,因为大多数更改可以通过更改单个方法来表达。
有 3 种常见的扩展模式
重写 state_dict。这是扩展保存过程的最简单方法,因为它不需要理解 SavePlan 工作原理的复杂性
>>> class RenamePlanner(DefaultSavePlanner): >>> def set_up_planner(self, state_dict, is_coordinator): >>> # prefix all keys with `foo_`` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
同时修改本地计划和查找。这在需要精细控制数据持久化方式时很有用。
>>> class FP16Planner(DefaultSavePlanner): >>> def create_local_plan(self): >>> plan = super().create_local_plan() >>> for p in plan: >>> if p.tensor_data is not None: >>> p.tensor_data.properties.dtype = torch.float16 >>> return plan >>> >>> def resolve_data(self, write_item): >>> item = super().resolve_data(write_item) >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
使用全局规划步骤进行中央决策,这些决策无法由每个进程单独做出。
>>> from itertools import islice >>> from dataclasses import replace >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 >>> # This sample doesn't handle ShardedTensors >>> def create_global_plan(self, all_plans): >>> def chunk(it, size): >>> it = iter(it) >>> return list(iter(lambda: tuple(islice(it, size)), ())) >>> all_plans = [ >>> replace(plan, items=items) for plan, items in >>> zip(all_plans, chunk(all_plans[0].items, len(all_plans))) >>> ] >>> return super().create_global_plan(all_plans)
最后,一些规划器需要在检查点中保存额外的元数据,这可以通过让每个进程在本地计划中贡献其数据项,然后由全局规划器聚合它们来实现。
>>> class SaveExtraDataPlanner(DefaultSavePlanner): >>> def create_local_plan(self) -> SavePlan: >>> plan = super().create_local_plan() >>> return replace(plan, planner_data="per-rank-data") >>> >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: >>> global_plan, metadata = super().create_global_plan(all_plans) >>> merged_data = [p.planner_data for p in global_plan] >>> metadata = replace(metadata, planner_data=merged_data) >>> return global_plan, metadata
- abstract create_local_plan()[source]¶
计算当前进程的保存计划。
这将被聚合并传递给 create_global_plan。规划器特定数据可以通过 SavePlan::planner_data 传递。
这在所有进程上调用。
- 返回类型
- abstract finish_plan(new_plan)[source]¶
将由 create_local_plan 创建的计划与 create_global_plan 的结果合并。
这在所有进程上调用。
- 返回类型
- abstract resolve_data(write_item)[source]¶
将
write_item
从state_dict
中转换并准备用于存储,确保幂等性和线程安全性。在
state_dict
中查找与write_item
关联的对象,并在存储层使用它之前应用任何转换(例如序列化)。在每个排名上调用多次,在最终的 SavePlan 中每个 WriteItem 至少调用一次。
此方法应该是幂等且线程安全的。StorageWriter 实现可以根据需要随意调用它。
任何分配内存的转换都应该在调用此方法时延迟执行,以减少检查点所需的峰值内存。
在返回张量时,它们可以位于任何设备或格式上,也可以是视图。存储层负责弄清楚如何保存它们。
- class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[source]¶
- class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[source]¶
保存需要写入存储的信息的数据类。
我们提供了一个基于文件系统的存储层
- class torch.distributed.checkpoint.filesystem.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000)[source]¶
使用文件 I/O 的 StorageWriter 的基本实现。
此实现做出了以下假设和简化
检查点路径是一个空目录或不存在的目录。
文件创建是原子的
检查点包含每个写入请求一个文件,以及一个包含序列化元数据的 .metadata 文件。
此外,我们还提供以下抽象来处理 Fsspec 存储。
- class torch.distributed.checkpoint.fsspec.FsspecWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000)[source]¶
使用 FFspec 的 StorageWriter 的基本实现。
此实现做出了以下假设和简化
检查点路径是一个空目录或不存在的目录。
文件创建是原子的
检查点包含每个写入请求一个文件,以及一个包含序列化元数据的 .metadata 文件。
我们提供了 LoadPlanner 和 SavePlanner 的默认实现,它们可以处理所有 torch.distributed 结构,例如 FSDP、DDP、ShardedTensor 和 DistributedTensor。
- class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None)[source]¶
- class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True)[source]¶
DefaultLoadPlanner 在 LoadPlanner 基础上添加了多个功能。
特别是它添加了以下内容
flatten_state_dict: 处理带有嵌套字典的 state_dict flatten_sharded_tensors: 适用于 2D 并行模式下的 FSDP
由于历史设计决策,FSDP 和 DDP 的状态字典可能具有不同的键或完全限定名称(例如,layer1.weight),即使原始未并行化模型相同。此外,FSDP 提供各种类型的模型状态字典,例如完整状态字典和分片状态字典。此外,优化器状态字典使用参数 ID 而不是完全限定名称来标识参数,这可能会在使用并行性(例如,流水线并行性)时导致问题。
为了解决这些挑战,我们提供了一组 API,供用户轻松管理 state_dicts。 get_model_state_dict 返回一个模型状态字典,其键与未并行化模型状态字典返回的键一致。类似地,get_optimizer_state_dict 提供优化器状态字典,其键在应用的所有并行性中保持一致。为了实现这种一致性,get_optimizer_state_dict 将参数 ID 转换为与未并行化模型状态字典中找到的完全限定名称相同的名称。
请注意,这些 API 返回的结果可以直接与 torch.distributed.checkpoint.save() 和 torch.distributed.checkpoint.load() 方法一起使用,无需任何额外的转换。
请注意,此功能处于实验阶段,API 签名将来可能会更改。
- torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[source]¶
返回模型的 state_dict 和优化器的 state_dict。
get_state_dict
可以处理任何由 PyTorch FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及这些并行方式的任何组合并行的模块。get_state_dict
的主要功能是:1.) 返回模型和优化器 state_dict,这些 state_dict 可以使用不同的训练器数量和/或不同的并行方式重新分片。2.) 隐藏特定于并行性的 state_dict API。用户无需调用这些 API。3.) 对结果 state_dict 进行健全性检查。结果状态字典的键是规范的 FQN(完全限定名称)。规范的 FQN 指的是基于参数在 nn.Module 层次结构中的位置的 FQN。更具体地说,参数的规范 FQN 是当模块未由任何并行方式分布时,由
module.named_parameters()
或module.named_buffers()
返回的 FQN。由于优化器在内部使用参数 ID 来表示参数,因此在调用此 API 时,将从参数 ID 转换为规范的 FQN。get_state_dict
也可以处理未并行的模块。在这种情况下,get_state_dict
只执行一个功能 - 将优化器参数 ID 转换为规范的 FQN。示例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.nn.parallel import DistributedDataParallel as DDP >>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model)) >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) >>> ddp_model = DDP(copy.deepcopy(model)) >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # the asserts will fail. >>> assert ddp_state_dict == fsdp_state_dict >>> assert ddp_optim_state == fsdp_optim_state_dict
- 参数
- 返回值
Tuple
包含模型 state_dict 和优化器 state_dict。- 返回类型
- torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[source]¶
返回
model
的模型 state_dict。有关详细用法,请参阅
get_state_dict
。
- torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[source]¶
返回优化器的组合 state_dict。
有关详细用法,请参阅
get_state_dict
。- 参数
- 返回值
用于
optimizers
的 state_dict。- 返回类型
OptimizerStateType
- torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[source]¶
加载模型 state_dict 和优化器 state_dict。
get_state_dict
的对应方法,用于将 state_dict 设置到模型和优化器。给定的model_state_dict
和optim_state_dict
不必由get_state_dict
返回,但必须满足以下要求:1) 所有 FQN 都是get_state_dict
中定义的规范 FQN,2) 如果张量被分片,它必须是 ShardedTensor 或 DTensor,3) 优化器 state_dict 不能包含参数 ID;键应该是规范的 FQN。- 参数
model (nn.Module) – 模型的 nn.Module。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化
model
的优化器。model_state_dict (Dict[str, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要加载的模型 state_dict。如果
model_state_dict
的键是 nn.Module,则键是model
的子模块,值应该是子模块的 state_dict。加载 state_dict 时,子模块的前缀将附加到 state_dict。optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器 state_dict。
options (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions。
- 返回值
missing_keys 是一个包含模型 state_dict 中缺失键的字符串列表。
unexpected_keys 是一个包含模型 state_dict 中意外键的字符串列表。
- 返回类型
NamedTuple
包含missing_keys
和unexpected_keys
字段
- torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[source]¶
加载模型 state_dict。
这是
get_model_state_dict
的对应函数,用于将 state_dict 设置到模型。有关详细用法,请参阅set_state_dict
。- 参数
model (nn.Module) – 模型的 nn.Module。
model_state_dict (Dict[str, ValueType]) – (Dict[str, ValueType]): 要加载的模型 state_dict。如果
model_state_dict
的键是 nn.Module,则该键是model
的子模块,而值应该是该子模块的 state_dict。加载 state_dict 时,子模块的前缀将附加到 state_dict。options (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions。
- 返回值
missing_keys 是一个包含缺失键的字符串列表
unexpected_keys 是一个包含意外键的字符串列表
- 返回类型
NamedTuple
包含missing_keys
和unexpected_keys
字段
- torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, *, optim_state_dict, options=None)[source]¶
加载优化器 state_dict。
这是
get_optimizer_state_dict
的对应函数,用于将 state_dict 设置到优化器。有关详细用法,请参阅set_state_dict
。- 参数
model (nn.Module) – 模型的 nn.Module。
optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化
model
的优化器。optim_state_dict (OptimizerStateType) – OptimizerStateType: 要加载的优化器 state_dict。
options (StateDictOptions) – 控制如何加载模型 state_dict 和优化器 state_dict 的选项。有关详细信息,请参阅 StateDictOptions。
- 返回值
无
- 返回类型
无
- class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True)[source]¶
此数据类指定了 get_state_dict/set_state_dict 的工作方式。
full_state_dict
: 如果设置为 True,则返回的 state_dict 中的所有张量都将被收集。返回的 state_dict 中不会包含任何 ShardedTensor 和 DTensor。cpu_offload
: 将所有张量卸载到 CPU。为了防止 CPU OOM,如果full_state_dict
也为 True,则只有 rank0 会获得 state_dict,而所有其他 rank 会获得空 state_dict。ignore_frozen_params
: 如果值为 True,则返回的 state_dict 不会包含任何冻结参数 -requires_grad
为 False。默认值为 False。keep_submodule_prefixes
: 当submodules
不为 None 时,此选项指示是否保留 state_dict 键中的子模块前缀。例如,如果子模块为module.pretrain
,参数的完整 FQN 为pretrain.layer1.weight
。当此选项为 True 时,参数在返回的 state_dict 中的键将为pretrain.layer1.weight
。如果选项为 False,则键将为layer1.weight
。请注意,如果keep_submodule_prefixes
为 False,则可能存在冲突的 FQN,因此submodules
中应该只有一个子模块。strict
: 当set_state_dict
调用 model.load_state_dict() 时,strict
选项。默认值为 False。
对于习惯使用和共享 torch.save 格式模型的用户,提供以下方法,用于在格式之间进行离线转换。
- torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[source]¶
给定包含 DCP 检查点的目录,此函数会将其转换为 Torch 保存文件。
- 参数
警告
为了避免 OOM,建议仅在一个排名上运行此函数。
- torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[source]¶
给定 torch 保存文件的路径,将其转换为 DCP 检查点。
- 参数
警告
为了避免 OOM,建议仅在一个排名上运行此函数。
以下类也可以用于在线加载和重新分片 torch.save 格式的模型。
- class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[source]¶
用于读取 Torch 保存文件的 StorageReader。此读取器将在协调器排名上读取整个检查点,然后将每个张量广播并分片到所有排名。
. 注意:旨在与 DynamicMetaLoadPlanner 一起使用
警告
当前实现仅支持加载张量。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )
- class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True)[source]¶
扩展 DefaultLoadPlanner,根据传入的状态字典创建新的 Metadata 对象,避免从磁盘读取元数据。这在读取没有元数据文件的格式(如 Torch Save 文件)时很有用。
. 注意:旨在与 BroadcastingTorchSaveReader 一起使用
警告
当前实现仅支持加载张量。
>>> sd = {"mode": model} >>> dcp.load( >>> sd, >>> storage_reader=BroadcastingTorchSaveReader(), >>> planner=DynamicMetaLoadPlanner(), >>> checkpoint_id="path_to_model.pt" >>> )