快捷方式

分布式检查点

PyTorch/XLA SPMD 通过专用的 Planner 实例与 torch.distributed.checkpoint 库兼容。用户可以通过此通用接口同步保存和加载检查点。

SPMDSavePlanner 和 SPMDLoadPlanner (src) 类使得 saveload 函数能够直接操作 XLAShardedTensor 的分片,从而在 SPMD 训练中实现分布式检查点的所有优点。

以下是同步分布式检查点 API 的演示

import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

# Saving a state_dict
state_dict = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
}

dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)
...

# Loading the model's state_dict from the checkpoint. The model should
# already be on the XLA device and have the desired sharding applied.
state_dict = {
    "model": model.state_dict(),
}

dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])

CheckpointManager

实验性的 CheckpointManager 接口在 torch.distributed.checkpoint 函数之上提供了一个更高级别的 API,以实现几个关键特性

  • 受管理的检查点: 通过 CheckpointManager 获取的每个检查点都通过获取它的步骤进行标识。所有跟踪的步骤都可以通过 CheckpointManager.all_steps 方法访问,并且可以使用 CheckpointManager.restore 恢复任何跟踪的步骤。

  • 异步检查点: 通过 CheckpointManager.save_async API 获取的检查点会异步写入持久存储,以在检查点期间不阻塞训练。输入的分片 state_dict 在检查点被分派到后台线程之前会先移动到 CPU。

  • 在抢占时自动检查点: 在 Cloud TPU 上,可以检测到抢占并在进程终止前获取检查点。要使用此功能,请确保您的 TPU 是通过启用了 Autocheckpointing enabled 的 QueuedResource 提供的,并在构建 CheckpointManager 时确保设置了 chkpt_on_preemption 参数(此选项默认启用)。

  • FSSpec 支持: CheckpointManager 使用 fsspec 存储后端,可以直接将检查点保存到任何兼容 fsspec 的文件系统,包括 GCS。

CheckpointManager 的使用示例如下

from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer

# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)

# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    # Choose the highest step
    best_step = max(tracked_steps)
    # Before restoring the checkpoint, the optimizer state must be primed
    # to allow state to be loaded into it.
    prime_optimizer(optim)
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])
    optim.load_state_dict(state_dict['optim'])

# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
    ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save_async(step, state_dict):
        print(f'Checkpoint taken at step {step}')

恢复优化器状态

在分布式检查点中,state_dicts 会原地加载,并且只加载检查点中所需的分片。由于优化器状态是延迟创建的,所以在第一次 optimizer.step 调用之前状态不会出现,并且尝试加载未 primed 的优化器将会失败。

为此提供了实用方法 prime_optimizer:它通过将所有梯度设置为零并调用 optimizer.step 来运行一个假的训练步骤。这是**一个破坏性方法**,会触及模型参数和优化器状态,因此只能在恢复之前立即调用。

进程组

要使用诸如分布式检查点之类的 torch.distributed API,需要一个进程组。在 SPMD 模式下,不支持 xla 后端,因为编译器负责所有的集合操作。

相反,必须使用诸如 gloo 的 CPU 进程组。在 TPU 上,仍然支持 xla:// init_method 来发现 master IP、全局 world size 和 host rank。以下是一个初始化示例

import torch.distributed as dist
# Import to register the `xla://` init_method
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr

xr.use_spmd()

# The `xla://` init_method will automatically discover master worker IP, rank,
# and global world size without requiring environment configuration on TPUs.
dist.init_process_group('gloo', init_method='xla://')

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源