作者:Meta 的 Lucas Pasqualin, Less Wright, Iris Zhang (PyTorch), Chien-Chin Huang;IBM 研究院的 Swaminathan Sundararaman, Saransh Gupta, Raghu Ganti

摘要:得益于 PyTorch 分布式的新异步检查点保存功能(由 IBM 提供反馈开发),我们展示了 IBM 研究院团队如何实现并将有效检查点保存时间缩短 10-20 倍。例如:一个 7B 模型的检查点“停机时间”从平均 148.8 秒缩短到 6.3 秒,提速达 23.62 倍。

这直接意味着,在持续稳健地进行检查点保存的同时,每 24 小时内可以获得更多的净训练进度,或者可以更频繁地进行检查点保存以缩短恢复窗口/时间。

在本说明中,我们将展示实现异步检查点保存所需的用法代码和架构,以及经 IBM 研究团队验证的时间测试结果。

Async Checkpointing vs Standard Checkpointing

模型检查点保存是大型模型训练的关键组成部分,但检查点保存是一个昂贵的过程,因为每个检查点保存过程都会阻塞训练进度,以便保存最新的模型权重。然而,不进行检查点保存或降低检查点保存频率可能会导致训练进度显著损失。例如,死锁、慢节点和 GPU 错误等故障需要重新启动训练过程。为了从故障中重新启动,所有(训练)工作进程必须停止其训练过程,并从上次保存的检查点重新启动。

因此,故障鲁棒性与训练进度之间的固有矛盾表现为一种权衡,但现在有了异步检查点保存,PyTorch 分布式能够显著缓解这种矛盾,并在对总训练时间影响最小的情况下实现频繁的检查点保存。

作为背景,大约 一年前,我们展示了分布式检查点保存如何从原来的 torch.save() 功能大幅加速了检查点保存时间。正如 IBM 研究院所指出的,torch.save 在 PyTorch 1.13 版本中保存一个 11B 模型可能需要长达 30 分钟。

随着分布式检查点保存的进步,对于最大 30B 模型,检查点保存可以在 4 分钟内完成。

借助异步检查点保存,由于检查点保存而损失的训练时间现在减少到 30 秒以下,并且通常短至 6 秒。

需要明确的是,异步检查点保存并不会像之前的更新那样压缩实际的序列化检查点保存时间。相反,它将最终的检查点保存过程从关键路径(到 CPU 线程)上移开,从而允许 GPU 训练继续进行,同时在独立的线程中完成检查点保存。

然而,对用户而言,效果几乎相同:由于检查点保存导致的训练停机时间大幅减少,在许多情况下缩短了 10 倍甚至 20 倍。

Async Dist Checkpointing

如上图所示的加速图表,异步检查点保存比一年前已有的重大改进又带来了 10 倍到 23 倍的进一步提升。

异步检查点保存如何工作?

异步检查点保存将检查点保存过程模块化为两个部分,而不是一个单一的整体过程。第一阶段将数据从每个 GPU/rank 从 GPU 复制到 CPU。这是用户可见的停机时间,对于 7B-13B 模型大小,可能需要 6 到 14 秒。第二阶段将数据从 CPU 内存异步复制到磁盘以持久化检查点。

一旦数据在第一阶段复制到 CPU,GPU 就可以立即恢复训练。因此,使用异步检查点保存,检查点保存的停机时间仅仅是复制最新模型状态到 CPU 所需的时间。

在训练恢复的同时,非阻塞的 CPU 线程使用内存中刚到的数据来完成到磁盘的完整检查点保存/序列化过程(即持久保存)。

flow diagram

请注意,PyTorch 的分布式检查点保存器依赖于集合通信调用来获取优化保存所需的每个 rank 的元数据,以及一个标记检查点保存完成并使操作具有原子性的最终同步。如果检查点保存线程使用与训练相同的进程组,这可能会干扰分布式训练(因为分布式训练也依赖类似的调用来同步多个 GPU 的训练)。

具体来说,调用之间的竞争条件可能导致训练和异步检查点保存线程同时等待集合操作调用,从而导致真正的集合操作挂起。

我们通过为异步检查点保存初始化一个单独的进程组来避免这种情况。这将检查点保存集合操作分离到其自己的逻辑进程组中,从而确保它不会干扰主训练线程中的集合操作调用。

如何在我的训练中使用异步检查点保存?

异步检查点保存的使用相对简单。使用最新的 PyTorch nightly 版本,您需要使用 nccl 和 gloo 初始化您的进程组。gloo 是 CPU 线程部分所必需的。

然后,创建一个异步检查点保存将使用的重复进程组。之后像往常一样训练,但在您想要进行检查点保存的时候,使用异步保存 API,传入要保存的状态、检查点 ID 和检查点进程组。

Code snippet

异步检查点保存也已在 torchtitan 中完全实现。在这里,它被实现用于预训练您自己的 Llama2 或 Llama3 模型。使用它就像更新 toml 配置文件一样简单。

Code snippet

未来工作

过去一年,检查点保存取得了巨大进步。从将近半小时的检查点保存时间,通过分布式检查点保存缩短到 5 分钟以下,现在通过异步检查点保存缩短到 30 秒以下。

最后的疆界——零开销检查点保存,即使是 < 30 秒的停机时间也通过在反向传播过程中流式传输更新的权重来消除,使得在异步检查点保存开始时检查点数据已经在 CPU 上。

这将有效地使大型模型训练达到检查点保存不造成任何中断或停机,从而同时提高鲁棒性(因为可以更频繁地进行检查点保存)并加快训练进度,因为检查点保存没有停机时间。

源代码链接: https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py