API 参考¶
- 类 Snapshot(path: str, pg: 可选[ProcessGroup] = 无, storage_options: 可选[字典[str, 任何]] = 无)¶
创建对现有快照的引用。
- 参数::
路径 (str) – 快照的路径。这应该与拍摄快照时用于
Snapshot.take()
的path
参数相同。pg (ProcessGroup,可选) –
Snapshot.restore()
参与者的进程组。如果为无,则将使用默认进程组。storage_options (字典[str,任何],可选) – 用于存储插件的其他关键字选项。有关自定义,请参阅每个存储插件的文档。
- 类方法 take(path: str, app_state: 字典[str, T], pg: 可选[ProcessGroup] = 无, replicated: 可选[列表[str]] = 无, storage_options: 可选[字典[str, 任何]] = 无, _custom_tensor_prepare_func: 可选[可调用[[str, 张量, 布尔], 张量]] = 无) Snapshot ¶
拍摄应用程序状态的快照。
- 参数::
app_state (字典[str,有状态]) – 要持久化的应用程序状态。它采用字典的形式,键是用户定义的字符串,值是有状态对象。有状态对象是公开
.state_dict()
和.load_state_dict()
方法的对象。常见的 PyTorch 对象,例如torch.nn.Module
、torch.optim.Optimizer
和 LR 调度器,都属于有状态对象。路径 (str) –
保存快照的位置。
path
可以有一个 URI 前缀(例如s3://
),用于指定存储后端。如果没有提供 URI 前缀,则假定path
是文件系统位置。对于分布式快照,如果参与的各个进程的path
不一致,则将使用进程 0 指定的值。对于多主机快照,path
需要是所有主机都可以访问的位置。注意
path
不得 指向现有快照。pg (ProcessGroup,可选) –
Snapshot.take()
参与者的进程组。如果为无,则将使用默认进程组。replicated (列表[str],可选) –
用于将检查点内容标记为已复制的全局模式。匹配的对象将被去重并在各个进程之间进行负载平衡。
注意
对于
DistributedDataParallel
,将自动推断复制属性。仅当您的模型具有完全复制的状态但不使用DistributedDataParallel
时,才指定此参数。storage_options (字典[str,任何],可选) – 用于存储插件的其他关键字选项。有关自定义,请参阅每个存储插件的文档。
- 返回值::
新拍摄的快照。
- classmethod async_take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[Callable[[str, Tensor, bool], Tensor]] = None) PendingSnapshot ¶
异步地从应用程序状态获取快照。
此函数与
Snapshot.take()
相同,不同之处在于它会尽早返回并在后台执行尽可能多的 I/O 操作,从而允许训练尽早恢复。- 参数::
app_state (Dict[str, Stateful]) – 与
Snapshot.take()
的app_state
参数相同。path (str) – 与
Snapshot.take()
的path
参数相同。pg (ProcessGroup, 可选) – 与
Snapshot.take()
的pg
参数相同。replicated (List[str], 可选) – 与
Snapshot.take()
的replicated
参数相同。storage_options (Dict[str, Any], 可选) – 与
Snapshot.take()
的storage_options
参数相同。
- 返回值::
挂起的快照的句柄。 该句柄公开了一个
.done()
方法用于查询进度,以及一个.wait()
方法用于等待快照完成。
- restore(app_state: Dict[str, T]) None ¶
从快照恢复应用程序状态。
- 参数::
app_state (Dict[str, Stateful]) – 要恢复的应用程序状态。
app_state
需要与获取快照时Snapshot.take()
使用的app_state
相同或为其子集。
- read_object(path: str, obj_out: Optional[T] = None, memory_budget_bytes: Optional[int] = None) T ¶
从快照内容中读取对象。
- 参数::
path (str) – 快照中目标对象的路径。
path
等同于快照清单中目标对象的键,可以通过Snapshot.get_manifest()
获取。obj_out (Any, 可选) –
如果指定,并且对象类型支持就地加载,则将对象就地加载到
obj_out
中。 否则,obj_out
将被忽略。注意
当目标对象是
ShardedTensor
并且obj_out
为 None 时,将返回分片张量的 CPU、完整张量版本。memory_budget_bytes (int, 可选) – 如果指定,读取操作将使临时内存缓冲区大小低于此阈值。
- 返回值::
从快照内容中读取的对象。
- get_manifest() Dict[str, Entry] ¶
返回快照清单。
字典中的每个条目对应于快照中的一个对象,键是对象的逻辑路径,值是描述对象的元数据。 对于分布式快照,清单包含所有进程保存的对象的条目。
- 返回值::
快照清单。
- class StateDict(dict=None, /, **kwargs)¶
一个公开
.state_dict()
和.load_state_dict()
方法的字典。它可以用来捕获不公开
.state_dict()
和.load_state_dict()
方法的对象(例如,张量、Python 原始类型)作为应用程序状态的一部分。
- class RNGState¶
用于保存和恢复全局 RNG 状态的特殊状态对象。
当在应用程序状态中捕获时,可以保证全局 RNG 状态在从快照恢复后与获取快照后的值相同。
示例
>>> Snapshot.take( >>> path="foo/bar", >>> app_state={"rng_state": RNGState()}, >>> ) >>> after_take = torch.rand(1) >>> # In the same process or in another process >>> snapshot = Snapshot(path="foo/bar") >>> snapshot.restore(app_state) >>> after_restore = torch.rand(1) >>> torch.testing.assert_close(after_take, after_restore)