摘要: PyTorch 分布式新的异步检查点功能,在 IBM 的反馈下开发,我们展示了 IBM 研究团队如何实现并将有效检查点时间缩短 10-20 倍。例如:7B 模型检查点的“停机时间”从平均 148.8 秒缩短到 6.3 秒,速度提高 23.62 倍。
这直接意味着在每个给定的 24 小时内可以取得更多的净训练进展,同时继续进行可靠的检查点,或者更频繁地进行检查点以缩短恢复窗口/时间。
在本说明中,我们展示了使异步检查点成为可能的使用代码和架构,以及经 IBM 研究团队验证的计时结果。

模型检查点是大型模型训练的重要组成部分,但检查点是一个昂贵的过程,因为每个检查点过程都涉及阻塞训练进程以保存最新的模型权重。然而,不进行检查点或降低检查点频率可能会导致训练进程的显著损失。例如,死锁、拖延和 GPU 错误等故障需要重新启动训练进程。为了从故障中重新启动,所有(训练)工作者必须停止其训练进程并从上次保存的检查点重新启动。
因此,故障鲁棒性与训练进程之间的固有矛盾表现为一种权衡,但现在通过异步检查点,PyTorch Distributed 能够显著缓解这种矛盾,并以对整体训练时间影响最小的方式实现频繁检查点。
背景:大约就在一年前,我们展示了分布式检查点如何大幅加快检查点时间,超越了原始的 torch.save() 功能。正如 IBM 研究指出,torch.save 检查点单个 11B 模型可能需要长达 30 分钟(PyTorch 1.13)。
随着分布式检查点的进步,对于高达 30B 模型大小的检查点可以在 4 分钟内完成。
通过异步检查点,因检查点而损失的训练时间现在减少到 30 秒以下,通常短至 6 秒。
需要明确的是,异步检查点并没有像之前的更新所展示的那样压缩实际的序列化检查点时间。相反,它将最终的检查点过程从关键路径(到 CPU 线程)移开,以允许 GPU 训练继续进行,同时在单独的线程中完成检查点。
然而,对于用户来说,效果几乎相同,因为由于检查点造成的训练停机时间大大减少,在许多情况下减少了 10 倍甚至 20 倍。

如上图所示,异步检查点在去年已有的巨大改进基础上,又带来了 10 到 23 倍的进一步提升。
异步检查点如何工作?
异步检查点将检查点过程模块化为两个部分,而不是一个单一的整体过程。第一阶段将数据从每个 GPU/rank 从 GPU 复制到 CPU。这是用户可见的停机时间,对于 7B-13B 模型大小可能需要 6-14 秒。第二阶段将数据从 CPU 内存异步复制到磁盘以持久化检查点。
一旦数据在第一阶段复制到 CPU,GPU 就可以立即恢复训练。因此,使用异步检查点时,检查点的停机时间就是将最新模型状态复制到 CPU 所需的时间。
在训练恢复的同时,非阻塞的 CPU 线程与内存中刚到达的数据协作,完成向磁盘进行完整检查点/序列化过程(即持久保存)。

请注意,PyTorch 的分布式检查点依赖于集体通信调用来获取优化保存所需的每个 rank 元数据,以及一个最终的同步,该同步标志着检查点完成并使操作具有原子性。如果检查点线程使用与训练相同的进程组,这可能会干扰分布式训练(因为分布式训练也依赖于类似的调用来同步跨多个 GPU 的训练)。
具体来说,这些调用之间的竞争条件可能会导致训练和异步检查点保存线程同时等待集体调用,从而导致真正的集体挂起。
我们通过为异步检查点初始化一个独立的进程组来避免这种情况。这使得检查点集体操作拥有自己的逻辑进程组,从而确保它不会干扰主训练线程中的集体调用。
如何在我的训练中使用异步检查点?
异步检查点的使用相对简单。使用最新的 PyTorch 每夜版本,您需要使用 nccl 和 gloo 初始化您的进程组。Gloo 是 CPU 线程部分所必需的。
在此之后,创建一个异步检查点将使用的重复进程组。然后像往常一样训练,但在您想要检查点时,使用异步保存 API,传入要保存的状态、检查点 ID 和检查点进程组。

异步检查点也在 torchtitan 中完全实现。在这里,它被实现用于预训练您自己的 Llama2 或 Llama3 模型。使用它就像更新 toml 配置文件一样简单。

未来工作
在过去一年中,检查点技术取得了巨大的进步。从将近半小时的检查点时间缩短到分布式检查点的不到 5 分钟,现在又通过异步检查点缩短到 30 秒以内。
最后一个前沿——零开销检查点,即使是小于 30 秒的时间也被消除,通过在反向传播过程中流式传输更新的权重,使得在异步检查点启动时检查点数据已经在 CPU 上。
这将有效地使大型模型训练达到检查点零中断或零停机,从而实现更高的鲁棒性(因为可以更频繁地进行检查点)和更快的训练进度(因为没有检查点停机时间)。
源代码链接:https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py