快捷方式

API 参考

Snapshot(path: str, pg: 可选[ProcessGroup] = , storage_options: 可选[字典[str, 任何]] = )

创建对现有快照的引用。

参数:
  • 路径 (str) – 快照的路径。这应该与拍摄快照时用于 Snapshot.take()path 参数相同。

  • pg (ProcessGroup可选) – Snapshot.restore() 参与者的进程组。如果为无,则将使用默认进程组。

  • storage_options (字典[str任何]可选) – 用于存储插件的其他关键字选项。有关自定义,请参阅每个存储插件的文档。

类方法 take(path: str, app_state: 字典[str, T], pg: 可选[ProcessGroup] = , replicated: 可选[列表[str]] = , storage_options: 可选[字典[str, 任何]] = , _custom_tensor_prepare_func: 可选[可调用[[str, 张量, 布尔], 张量]] = ) Snapshot

拍摄应用程序状态的快照。

参数:
  • app_state (字典[str有状态]) – 要持久化的应用程序状态。它采用字典的形式,键是用户定义的字符串,值是有状态对象。有状态对象是公开 .state_dict().load_state_dict() 方法的对象。常见的 PyTorch 对象,例如 torch.nn.Moduletorch.optim.Optimizer 和 LR 调度器,都属于有状态对象。

  • 路径 (str) –

    保存快照的位置。 path 可以有一个 URI 前缀(例如 s3://),用于指定存储后端。如果没有提供 URI 前缀,则假定 path 是文件系统位置。对于分布式快照,如果参与的各个进程的 path 不一致,则将使用进程 0 指定的值。对于多主机快照,path 需要是所有主机都可以访问的位置。

    注意

    path 不得 指向现有快照。

  • pg (ProcessGroup可选) – Snapshot.take() 参与者的进程组。如果为无,则将使用默认进程组。

  • replicated (列表[str]可选) –

    用于将检查点内容标记为已复制的全局模式。匹配的对象将被去重并在各个进程之间进行负载平衡。

    注意

    对于 DistributedDataParallel,将自动推断复制属性。仅当您的模型具有完全复制的状态但不使用 DistributedDataParallel 时,才指定此参数。

  • storage_options (字典[str任何]可选) – 用于存储插件的其他关键字选项。有关自定义,请参阅每个存储插件的文档。

返回值:

新拍摄的快照。

classmethod async_take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, _custom_tensor_prepare_func: Optional[Callable[[str, Tensor, bool], Tensor]] = None) PendingSnapshot

异步地从应用程序状态获取快照。

此函数与 Snapshot.take() 相同,不同之处在于它会尽早返回并在后台执行尽可能多的 I/O 操作,从而允许训练尽早恢复。

参数:
  • app_state (Dict[str, Stateful]) – 与 Snapshot.take()app_state 参数相同。

  • path (str) – 与 Snapshot.take()path 参数相同。

  • pg (ProcessGroup, 可选) – 与 Snapshot.take()pg 参数相同。

  • replicated (List[str], 可选) – 与 Snapshot.take()replicated 参数相同。

  • storage_options (Dict[str, Any], 可选) – 与 Snapshot.take()storage_options 参数相同。

返回值:

挂起的快照的句柄。 该句柄公开了一个 .done() 方法用于查询进度,以及一个 .wait() 方法用于等待快照完成。

restore(app_state: Dict[str, T]) None

从快照恢复应用程序状态。

参数:

app_state (Dict[str, Stateful]) – 要恢复的应用程序状态。 app_state 需要与获取快照时 Snapshot.take() 使用的 app_state 相同或为其子集。

read_object(path: str, obj_out: Optional[T] = None, memory_budget_bytes: Optional[int] = None) T

从快照内容中读取对象。

参数:
  • path (str) – 快照中目标对象的路径。 path 等同于快照清单中目标对象的键,可以通过 Snapshot.get_manifest() 获取。

  • obj_out (Any, 可选) –

    如果指定,并且对象类型支持就地加载,则将对象就地加载到 obj_out 中。 否则,obj_out 将被忽略。

    注意

    当目标对象是 ShardedTensor 并且 obj_out 为 None 时,将返回分片张量的 CPU、完整张量版本。

  • memory_budget_bytes (int, 可选) – 如果指定,读取操作将使临时内存缓冲区大小低于此阈值。

返回值:

从快照内容中读取的对象。

get_manifest() Dict[str, Entry]

返回快照清单。

字典中的每个条目对应于快照中的一个对象,键是对象的逻辑路径,值是描述对象的元数据。 对于分布式快照,清单包含所有进程保存的对象的条目。

返回值:

快照清单。

class StateDict(dict=None, /, **kwargs)

一个公开 .state_dict().load_state_dict() 方法的字典。

它可以用来捕获不公开 .state_dict().load_state_dict() 方法的对象(例如,张量、Python 原始类型)作为应用程序状态的一部分。

class RNGState

用于保存和恢复全局 RNG 状态的特殊状态对象。

当在应用程序状态中捕获时,可以保证全局 RNG 状态在从快照恢复后与获取快照后的值相同。

示例

>>> Snapshot.take(
>>>     path="foo/bar",
>>>     app_state={"rng_state": RNGState()},
>>> )
>>> after_take = torch.rand(1)

>>> # In the same process or in another process
>>> snapshot = Snapshot(path="foo/bar")
>>> snapshot.restore(app_state)
>>> after_restore = torch.rand(1)

>>> torch.testing.assert_close(after_take, after_restore)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的答案

查看资源