入门指南¶
Snapshot
是 TorchSnapshot 的核心 API。该类表示持久存储在存储中的应用程序状态。用户可以通过 Snapshot.take()
(也称为保存检查点)获取应用程序的快照,并通过 Snapshot.restore()
(也称为加载检查点)从快照恢复应用程序的状态。
安装¶
有关安装说明,请参阅 README.md。
描述应用程序状态¶
在使用 Snapshot
保存或恢复应用程序状态之前,用户需要**描述应用程序状态**。这是通过创建一个字典来完成的,该字典包含用户希望捕获为应用程序状态的所有**状态对象**。任何公开 .state_dict()
和 .load_state_dict()
的对象都被视为状态对象。常见的 PyTorch 对象(例如 Module
、Optimizer
和 LR 调度器)都符合状态对象的条件,可以直接捕获。不符合此要求的对象可以通过 StateDict
捕获。
from torchsnapshot import StateDict
app_state = {
"model": model,
"optimizer": optimizer,
"extra_state": StateDict(iterations=0)
}
获取快照¶
描述应用程序状态后,用户可以通过 Snapshot.take()
获取应用程序的快照,该快照会将应用程序状态持久存储在用户指定的路径中,并返回对该快照的引用。
TorchSnapshot 开箱即用地提供了与常用云对象存储的高性能且可靠的集成。用户可以通过在路径前添加 URI 前缀来选择不同的存储后端(例如,s3://
用于 S3,gs://
用于 Google Cloud Storage)。默认情况下,前缀为 fs://
,这表示该路径是文件系统位置。
from torchsnapshot import Snapshot
# Persist the application state to local FS or network FS
snapshot = Snapshot.take(path="/path/to/my/snapshot", app_state=app_state)
# Alternatively, persist the application state to S3
snapshot = Snapshot.take(
path="s3://bucket/path/to/my/snapshot",
app_state=app_state
)
注意
不要在使用 TorchSnapshot 保存 GPU 张量之前将其移动到 CPU。TorchSnapshot 实现了各种优化,用于提高 GPU 到存储传输的吞吐量并减少主机内存的使用。手动将 GPU 张量移动到 CPU 会降低吞吐量,并增加“内存不足”问题的可能性。
从快照恢复¶
要从快照恢复,用户首先需要获取对该快照的引用。如前所述,在获取快照的过程中,Snapshot.take()
会返回对该快照的引用。在另一个进程中(这对于恢复更为常见),可以通过使用快照路径创建 Snapshot
对象来获取引用。
要从快照恢复应用程序状态,请使用应用程序状态调用 Snapshot.restore()
from torchsnapshot import Snapshot
snapshot = Snapshot(path="/path/to/my/snapshot")
snapshot.restore(app_state=app_state)
注意
Snapshot.restore()
尽可能地就地恢复状态对象,以避免创建不必要的中间状态副本。
分布式快照¶
TorchSnapshot 将分布式应用程序作为一等公民提供支持。要获取分布式应用程序的快照,只需在所有排名上同时调用 Snapshot.take()
(类似于调用 torch.distributed API)。持久化的应用程序状态将被组织为单个快照。
TorchSnapshot 通过在所有排名之间均匀分配写入工作负载,极大地提高了分布式数据并行应用程序的检查点性能(基准测试)。速度的提高是由于更好地利用了 GPU 复制单元和存储 I/O 并行化。
ddp_model = DistributedDataParallel(model)
app_state = {"model": ddp_model}
snapshot = Snapshot.take(path="/path/to/my/snapshot", app_state=app_state)
快照内容访问¶
即使快照存储在云对象存储中,也可以在不获取整个快照的情况下高效地访问快照中的对象。这对于传输学习和后处理模型非常有用,这些模型太大而无法容纳在单个主机/设备中。
snapshot = Snapshot(path="/path/to/my/snapshot")
# Available object paths can be queried with snapshot.get_manifest()
layer_0_weight = snapshot.read_object(path="0/model/layer_0.weight")
异步获取快照¶
当主机内存充足时,用户可以使用 Snapshot.async_take()
来允许在所有存储 I/O 完成之前恢复训练。一旦 Snapshot.async_take()
在主机 RAM 中暂存快照内容并在后台安排存储 I/O,它就会立即返回。这可以大大减少为检查点阻塞的时间,尤其是在底层存储速度较慢的情况下。
pending_snapshot = Snapshot.async_take(
path="/path/to/my/snapshot",
app_state=app_state,
)
# Users can query the pending snapshot's status
if pending_snapshot.done():
...
# ... or wait for the pending snapshot to complete
snapshot = pending_snapshot.wait()
注意
尽管 API 名称中带有“async”,但通过 Snapshot.async_take()
创建的快照是一致且确定的。
可重复性¶
TorchSnapshot 提供了一个名为 RNGState
的实用程序来帮助用户管理可重复性。如果在应用程序状态中捕获了 RNGState
对象,TorchSnapshot 确保在从快照恢复后,全局 RNG 状态将设置为与获取快照后相同的值。
from torchsnapshot import Snapshot, RNGState
app_state = {"model": model, "optimizer": optimizer, "rng_state": RNGState()}
snapshot = Snapshot.take(path="/path/to/my/snapshot", app_state=app_state)
# global RNG state => {x}
# In the same process or in another process
snapshot.restore(app_state=app_state)
# global RNG state => {x}
弹性(实验性)¶
只要快照只包含**复制**对象或**分片**对象,分布式应用程序就可以从使用不同世界大小获取的快照恢复
复制对象是指 (1) 在所有排名下以相同的 state dict 密钥存在,并且 (2) 在
Snapshot.take()
期间在所有排名上都持有相同值的 对象。复制对象的一个例子是DistributedDataParallel
的 state dict 中的张量。当恢复复制对象时,所有新加入的排名都可以使用它。分片对象是指其状态在多个排名之间分片的 对象。目前,唯一支持的分片对象是
ShardedTensor
。不同排名上具有相同 state dict 密钥的ShardedTensor
被视为同一个全局张量的一部分。当全局张量的分片由于恢复时世界大小的变化而发生变化时,全局张量将自动正确地重新分片。
注意
如果一个对象既不是复制的也不是分片的,则在恢复时只能由保存排名加载它。这可以防止意外地将非弹性模型视为弹性模型。