摘要:借助 PyTorch 分布式新增的异步检查点(Asynchronous Checkpointing)功能(该功能根据 IBM 的反馈开发而成),我们展示了 IBM 研究团队如何实现检查点保存,并将有效检查点时间缩短了 10-20 倍。示例:7B 模型进行检查点保存的“停机时间”从平均 148.8 秒缩短至 6.3 秒,速度提升了 23.62 倍。
这意味着在同样的 24 小时周期内,在保持稳健检查点保存的同时,可以获得更多的净训练进度,或者通过更频繁的检查点保存来缩短恢复窗口/时间。
在本文中,我们将展示使异步检查点成为可能的用法代码和架构,以及经由 IBM 研究团队验证的计时结果。

模型检查点保存是大型模型训练的重要组成部分,但它是一个昂贵的过程,因为每个检查点进程都需要阻塞训练进度以保存最新的模型权重。然而,不保存检查点或降低保存频率可能会导致训练进度的严重损失。例如,死锁、掉队者(straggler)和 GPU 错误等故障需要重新启动训练进程。为了从故障中恢复,所有(训练)工作节点必须停止其训练进程,并从最后保存的检查点重新开始。
因此,故障稳健性与训练进度之间的内在张力是一种权衡,但现在有了异步检查点功能,PyTorch Distributed 能够显著降低这种张力,并在对整体训练时间影响最小的情况下实现频繁的检查点保存。
背景方面,就在大约一年前,我们展示了分布式检查点保存是如何在原有 torch.save() 功能的基础上大幅加快检查点保存时间的。正如 IBM Research 所指出的,torch.save 对单个 11B 模型进行检查点保存可能需要长达 30 分钟(PyTorch 1.13 版本)。
随着分布式检查点技术的进步,对于高达 30B 的模型,检查点保存可以在 4 分钟内完成。
而通过异步检查点,因检查点保存而损失的训练时间现在已缩短至 30 秒以内,通常甚至短至 6 秒。
需要明确的是,异步检查点并不会压缩实际序列化检查点的耗时(如前次更新所述)。相反,它将最终的检查点保存过程从关键路径(移至 CPU 线程)移出,从而允许 GPU 训练在独立线程完成检查点保存的同时继续进行。
然而,对于用户而言,效果几乎是一样的,即因检查点保存而导致的训练停机时间大幅减少,在许多情况下减少了 10 倍甚至 20 倍。

正如上方的加速图表所示,异步检查点在一年前实现的巨大改进基础上,又进一步提升了 10 倍到 23 倍的效率。
异步检查点是如何工作的?
异步检查点将检查点保存过程模块化为两个部分,而不是一个单体过程。第一阶段将数据从每个 GPU/Rank 复制到 CPU。这是用户可见的停机时间,对于 7B-13B 的模型大小,这可能需要 6-14 秒。第二阶段异步地将数据从 CPU 内存复制到磁盘以持久化检查点。
一旦数据在第一阶段被复制到 CPU,GPU 就可以立即恢复训练。因此,通过异步检查点,检查点保存的停机时间仅仅是复制最新模型状态到 CPU 所需的时间。
在训练恢复的同时,非阻塞 CPU 线程会处理内存中刚刚接收到的数据,以完成完整的检查点/序列化过程并写入磁盘(即持久保存)。

请注意,PyTorch 的分布式检查点器依赖于集合通信调用来获取每个 Rank 的元数据,这是优化保存所必需的,同时也依赖于一次最终同步来标记检查点保存完成并使该操作具有原子性。如果检查点线程使用了与训练相同的进程组,这可能会干扰分布式训练(因为分布式训练也依赖类似的调用来同步跨多个 GPU 的训练)。
具体来说,这些调用之间的竞争条件可能会导致训练和异步检查点保存线程同时等待集合通信调用,从而导致真正的集合通信挂起。
我们通过为异步检查点初始化一个单独的进程组来避免这种情况。这会将检查点保存的集合通信隔离到它们自己逻辑的进程组中,从而确保它不会干扰主训练线程中的集合通信调用。
如何在训练中使用异步检查点?
异步检查点的使用相对简单。使用最新版本的 PyTorch Nightly,你需要同时使用 nccl 和 gloo 来初始化进程组。Gloo 是 CPU 线程部分所必需的。
在此基础上,创建一个异步检查点将要使用的重复进程组。然后像往常一样进行训练,但在需要保存检查点时,使用异步保存 API,传入要保存的状态、检查点 ID 和检查点进程组。

异步检查点也已在 torchtitan 中完整实现。在这里,它被实现用于预训练你自己的 Llama2 或 Llama3 模型。使用它只需更新 toml 配置文件即可。

未来工作
过去一年中,检查点技术取得了巨大进步。从近半小时的检查点保存时间,到分布式检查点的 5 分钟以内,再到现在的异步检查点的 30 秒以内。
最后的疆界是“零开销检查点”——通过在反向传播过程中流式传输更新后的权重,使得在异步检查点启动时检查点数据已在 CPU 上,从而消除那 30 秒以内的停机时间。
这将有效地使大型模型训练实现零干扰或零停机,从而既能提高稳健性(因为可以更频繁地获取检查点),又能通过消除检查点停机时间来加快训练进度。
源代码链接: https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py