快捷方式

入门指南

Snapshot 是 TorchSnapshot 的核心 API。该类表示持久存储在存储中的应用程序状态。用户可以通过 Snapshot.take()(也称为保存检查点)获取应用程序的快照,并通过 Snapshot.restore()(也称为加载检查点)从快照恢复应用程序的状态。

安装

有关安装说明,请参阅 README.md

描述应用程序状态

在使用 Snapshot 保存或恢复应用程序状态之前,用户需要**描述应用程序状态**。这是通过创建一个字典来完成的,该字典包含用户希望捕获为应用程序状态的所有**状态对象**。任何公开 .state_dict().load_state_dict() 的对象都被视为状态对象。常见的 PyTorch 对象(例如 ModuleOptimizer 和 LR 调度器)都符合状态对象的条件,可以直接捕获。不符合此要求的对象可以通过 StateDict 捕获。

from torchsnapshot import StateDict

app_state = {
    "model": model,
    "optimizer": optimizer,
    "extra_state": StateDict(iterations=0)
}

获取快照

描述应用程序状态后,用户可以通过 Snapshot.take() 获取应用程序的快照,该快照会将应用程序状态持久存储在用户指定的路径中,并返回对该快照的引用。

TorchSnapshot 开箱即用地提供了与常用云对象存储的高性能且可靠的集成。用户可以通过在路径前添加 URI 前缀来选择不同的存储后端(例如,s3:// 用于 S3gs:// 用于 Google Cloud Storage)。默认情况下,前缀为 fs://,这表示该路径是文件系统位置。

from torchsnapshot import Snapshot

# Persist the application state to local FS or network FS
snapshot = Snapshot.take(path="/path/to/my/snapshot", app_state=app_state)

# Alternatively, persist the application state to S3
snapshot = Snapshot.take(
    path="s3://bucket/path/to/my/snapshot",
    app_state=app_state
)

注意

不要在使用 TorchSnapshot 保存 GPU 张量之前将其移动到 CPU。TorchSnapshot 实现了各种优化,用于提高 GPU 到存储传输的吞吐量并减少主机内存的使用。手动将 GPU 张量移动到 CPU 会降低吞吐量,并增加“内存不足”问题的可能性。

从快照恢复

要从快照恢复,用户首先需要获取对该快照的引用。如前所述,在获取快照的过程中,Snapshot.take() 会返回对该快照的引用。在另一个进程中(这对于恢复更为常见),可以通过使用快照路径创建 Snapshot 对象来获取引用。

要从快照恢复应用程序状态,请使用应用程序状态调用 Snapshot.restore()

from torchsnapshot import Snapshot

snapshot = Snapshot(path="/path/to/my/snapshot")
snapshot.restore(app_state=app_state)

注意

Snapshot.restore() 尽可能地就地恢复状态对象,以避免创建不必要的中间状态副本。

分布式快照

TorchSnapshot 将分布式应用程序作为一等公民提供支持。要获取分布式应用程序的快照,只需在所有排名上同时调用 Snapshot.take()(类似于调用 torch.distributed API)。持久化的应用程序状态将被组织为单个快照。

TorchSnapshot 通过在所有排名之间均匀分配写入工作负载,极大地提高了分布式数据并行应用程序的检查点性能(基准测试)。速度的提高是由于更好地利用了 GPU 复制单元和存储 I/O 并行化。

ddp_model = DistributedDataParallel(model)
app_state = {"model": ddp_model}
snapshot = Snapshot.take(path="/path/to/my/snapshot", app_state=app_state)

快照内容访问

即使快照存储在云对象存储中,也可以在不获取整个快照的情况下高效地访问快照中的对象。这对于传输学习和后处理模型非常有用,这些模型太大而无法容纳在单个主机/设备中。

snapshot = Snapshot(path="/path/to/my/snapshot")

# Available object paths can be queried with snapshot.get_manifest()
layer_0_weight = snapshot.read_object(path="0/model/layer_0.weight")

异步获取快照

当主机内存充足时,用户可以使用 Snapshot.async_take() 来允许在所有存储 I/O 完成之前恢复训练。一旦 Snapshot.async_take() 在主机 RAM 中暂存快照内容并在后台安排存储 I/O,它就会立即返回。这可以大大减少为检查点阻塞的时间,尤其是在底层存储速度较慢的情况下。

pending_snapshot = Snapshot.async_take(
    path="/path/to/my/snapshot",
    app_state=app_state,
)

# Users can query the pending snapshot's status
if pending_snapshot.done():
    ...

# ... or wait for the pending snapshot to complete
snapshot = pending_snapshot.wait()

注意

尽管 API 名称中带有“async”,但通过 Snapshot.async_take() 创建的快照是一致且确定的。

可重复性

TorchSnapshot 提供了一个名为 RNGState 的实用程序来帮助用户管理可重复性。如果在应用程序状态中捕获了 RNGState 对象,TorchSnapshot 确保在从快照恢复后,全局 RNG 状态将设置为与获取快照后相同的值。

from torchsnapshot import Snapshot, RNGState

app_state = {"model": model, "optimizer": optimizer, "rng_state": RNGState()}
snapshot = Snapshot.take(path="/path/to/my/snapshot", app_state=app_state)
# global RNG state => {x}

# In the same process or in another process
snapshot.restore(app_state=app_state)
# global RNG state => {x}

弹性(实验性)

只要快照只包含**复制**对象或**分片**对象,分布式应用程序就可以从使用不同世界大小获取的快照恢复

  • 复制对象是指 (1) 在所有排名下以相同的 state dict 密钥存在,并且 (2) 在 Snapshot.take() 期间在所有排名上都持有相同值的 对象。复制对象的一个例子是 DistributedDataParallel 的 state dict 中的张量。当恢复复制对象时,所有新加入的排名都可以使用它。

  • 分片对象是指其状态在多个排名之间分片的 对象。目前,唯一支持的分片对象是 ShardedTensor。不同排名上具有相同 state dict 密钥的 ShardedTensor 被视为同一个全局张量的一部分。当全局张量的分片由于恢复时世界大小的变化而发生变化时,全局张量将自动正确地重新分片。

注意

如果一个对象既不是复制的也不是分片的,则在恢复时只能由保存排名加载它。这可以防止意外地将非弹性模型视为弹性模型。

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源