跳转到主要内容
博客

PyTorch 异步检查点速度提升 6 倍,使用缓存计划,无 GIL 争用

作者: 2025 年 4 月 30 日2025 年 5 月 3 日暂无评论

Meta:Less Wright, Meet Vadakkanchery, Saurabh Mishra, Ela Krepska, Hamid Shojanazeri, Pradeep Fernando
Crusoe:Ethan Petersen, Martin Cala, Chip Smith

PyTorch DCP(分布式检查点)最近在异步检查点中启用了新的优化,通过最小化集体开销和提高整体检查点效率来减少 GPU 利用率下降。

使用 Crusoe 的 2K H200 集群,利用 TorchTitan 训练 Llama3-70B,我们能够验证这些新功能在 1856 个 GPU 规模下提供了显著的加速,将异步 DCP 检查点的后台处理时间从约 436 秒减少到约 67 秒。

这使得后台检查点处理时间缩短了约 6.5 倍,从而使更多的总训练时间能够以全训练吞吐量进行。

图 1:1856 次高频检查点训练运行。第一个检查点(tps 下降)没有缓存保存计划,其后台处理时间远长于使用缓存计划的其余检查点。

背景:什么是异步检查点?

在标准检查点工作流中,当检查点数据从 GPU 卸载到 CPU 然后写入存储时,GPU 会被阻塞。保存到物理介质完成后,训练才能恢复。

异步检查点通过允许实际保存到存储由 CPU 线程完成,大大减少了停机时间,从而在检查点数据并行持久化时,允许基于 GPU 的训练继续进行。它主要用于中间/容错检查点,因为它比同步检查点更快地解除 GPU 阻塞。
例如,在我们的S大规模实验中,当检查点数据从 GPU 移动到 CPU(暂存)时,GPU 训练被阻塞不到一秒(在 1856 规模下为 0.78 秒)。此时,GPU 训练立即继续,这比传统检查点大大改善了训练时间。有关参考,异步检查点在此处有更详细的介绍。

异步检查点的挑战

然而,异步检查点固有的后台处理存在额外的挑战,导致在存储阶段完成时训练吞吐量暂时下降。这些挑战如下。

GIL 争用导致的 GPU 利用率下降

Python 中的全局解释器锁(GIL)是一种机制,它阻止多个原生线程同时执行 Python 字节码。这种锁主要是因为 CPython 的内存管理不是线程安全的而必需的。

DCP 目前使用后台线程进行元数据收集和上传到存储。尽管这些耗时步骤是异步完成的,但它会导致与训练器线程争用 GIL。这会导致 GPU 利用率(QPS)显著下降,并增加端到端上传延迟。对于大规模检查点,CPU 并行处理的开销对净 GPU 训练速度有抑制作用,因为 CPU 也通过 GPU 内核启动来驱动训练过程。

请参考我们实验中的下图:

图 2:可以看到,即使暂存(即对训练器的阻塞操作)完成后,训练 QPS 仍持续下降。

图 2 中的第一个下降(由紫色线标记)表示暂存已完成,训练可以继续。然而,第二个下降是显而易见的(由紫色线和黄色线之间的区域标记),这是由于训练器线程和检查点线程争用 Python GIL 导致的,导致训练 QPS 下降,直到检查点线程完成执行。

集体通信成本

DCP 目前出于各种原因执行多次集体通信:去重、检查点的全局元数据、重新分片和分布式异常处理。集体通信成本高昂,因为它们需要网络 I/O 以及跨 GPU 网络发送的大量元数据的 pickle/unpickle。随着作业规模的增长,这些集体通信变得极其昂贵,导致显著更高的端到端延迟和集体超时可能性。

解决方案

基于进程的异步检查点

DCP 现在支持通过后台进程进行异步检查点保存。这有助于消除与训练器线程的 python GIL 争用,从而避免训练 QPS 下降。请参阅图 2(通过线程进行检查点)和图 3(通过后台进程进行检查点)。

保存计划缓存

DCP 在规划和存储 I/O 步骤之间有清晰的边界。DCP 中的 SavePlanner 是一个有状态组件,充当 state_dict 的访问代理。Planner 管理由各个 rank 准备的保存计划,这些计划包含执行写入 I/O 所需的元数据信息。规划步骤涉及一个集体操作,以在协调器 rank 上收集检查点的全面视图。协调器 rank 负责去重参数/权重以消除冗余,验证全局计划以确保准确性和一致性,并创建全局元数据结构。接下来是一个分散集体操作,协调器 rank 将 I/O 任务分配给每个 rank。对计划进行的任何转换都会影响存储组件最终写入数据的方式。

在训练作业过程中,会保存多个检查点。在大多数情况下,只有检查点数据在不同的保存实例之间发生变化,因此,计划保持不变。这为我们提供了缓存计划的机会,只在第一次保存时支付规划成本,然后将该成本分摊到所有后续尝试中。只有更新的计划(在下一次尝试中更改的计划)才通过集体通信发送,从而显著降低了集体开销。

实验结果

设置: 1856 个 H200 GPU,Llama3-70B,HSDP2 和 TorchTitan

部署上述两种解决方案后,关键结果如下:

  • TPS 下降显著收窄,峰值下降至 372 与 315 tps,且持续时间大大缩短(约 67 秒 vs 约 437 秒)。这个时间窗口现在主要归因于 CPU 处理的阻塞。
  • 由于规划阶段的开销非常低,后续检查点保存尝试也继续快得多。因此,端到端延迟提高了 6.5 倍以上。这将使我们的合作伙伴能够增加检查点频率,并减少训练进度的损失(即浪费的训练时间)。

如果您查看图 1 中的第一个下降峰值,GPU 处理时间在此次下降中将训练吞吐量从 700 降低到 320 tps,并抑制了大约 7 分钟(467 秒)。一旦 CPU 完成处理,训练将再次以全速继续。

以前,这种约 7 分钟的抑制会在每次检查点重复。然而,通过新的基于进程的检查点功能,只有第一个检查点具有完整的下降时间(主要由于守护进程初始化开销),因为所有未来的检查点都通过后台进程执行,从而缓解了与训练器线程的 GIL 争用。

这在所有后续检查点中都有视觉显示,其中平均 MFU 抑制时间降至略多于一分钟,这体现在几乎立即恢复到完整 MFU 吞吐量的急剧峰值中。

图 3:红色框显示了非缓存计划检查点,其中还包括检查点后台初始化进程开销,而紫色框突出显示了第一个使用缓存计划运行的检查点。

这意味着即使大规模检查点,例如图 2 中在 1856 GPU 规模下所示的检查点,也可以将训练吞吐量影响降低约 6 倍。这使得异步 DCP 检查点可以更频繁地运行(从而提供更好的回滚保护),同时相对于以前的异步检查点开销,提高了总训练吞吐量。

使用 DCP 的缓存检查点

此功能已作为 PyTorch 每夜构建的一部分提供,您可以直接在 TorchTitan 中测试 PyTorch 的异步 DCP 检查点。以下是启用这些功能的说明:

  • 基于进程的异步检查点
    • async_save API 中将 async_checkpointer_type 设置为 AsyncCheckpointerType.PROCESS。(文件:pytorch/torch/distributed/checkpoint/state_dict_saver.py)
  • 保存计划缓存
    • DefaultSavePlanner 中将 enable_plan_caching 标志设置为 true。(文件:pytorch/torch/distributed/checkpoint/default_planner.py)

未来工作

DCP 将推出更多优化,以进一步降低检查点成本。目前,即使保存计划已缓存,协调器 rank 仍然会准备元数据。对于具有许多张量的大型作业和模型,此开销不可忽略。在下一个迭代中,DCP 将消除元数据开销并进一步提高端到端延迟。DCP 还将引入其他优化,例如零开销检查点,以实现在大规模作业中高效的检查点。

敬请期待!