Meta:Less Wright, Meet Vadakkanchery, Saurabh Mishra, Ela Krepska, Hamid Shojanazeri, Pradeep Fernando
Crusoe:Ethan Petersen, Martin Cala, Chip Smith
PyTorch DCP(分布式检查点)最近在异步检查点功能中实现了多项新优化,通过最小化集合通信开销并提高整体检查点效率,显著降低了 GPU 利用率的下降幅度。
在使用 Crusoe 的 2K H200 集群、运行 TorchTitan 并训练 Llama3-70B 模型的测试中,我们验证了这些新功能在 1856 个 GPU 的规模下带来了显著的速度提升,将异步 DCP 检查点的后台处理时间从约 436 秒缩短至约 67 秒。
这相当于后台检查点处理时间缩短了约 6.5 倍,从而使总训练时长中能够以全速吞吐量运行的时间占比大幅增加。

图 1:1856 个 GPU 的训练运行,伴随高频检查点保存。第一次检查点(TPS 下降处)没有缓存保存计划,其后台处理时间比使用缓存计划的后续检查点长得多。
背景:什么是异步检查点?
在标准的检查点工作流中,当检查点数据从 GPU 卸载到 CPU 并写入存储时,GPU 会处于阻塞状态。只有在物理介质保存完成后,训练才能恢复。
异步检查点通过允许 CPU 线程负责实际的存储写入,从而极大地减少了这种停机时间。这使得 GPU 能够一边进行训练,一边并行持久化检查点数据。它主要用于中间检查点或容错检查点,因为它比同步检查点能更快地解除对 GPU 的阻塞。
例如,在我们的超大规模实验中,当检查点数据从 GPU 移动到 CPU(暂存)时,GPU 训练仅被阻塞了不到一秒(在 1856 规模下为 0.78 秒)。随后,GPU 训练立即继续,这比传统检查点大大改善了训练时间。有关异步检查点的更多详细信息,请参阅此处。
异步检查点的挑战
然而,异步检查点固有的后台处理过程带来了一些额外的挑战,导致在完成存储阶段时训练吞吐量会暂时下降。具体如下所述。
GIL 争用导致的 GPU 利用率下降
Python 的全局解释器锁 (GIL) 是一种机制,用于防止多个原生线程同时执行 Python 字节码。这种锁之所以必要,主要是因为 CPython 的内存管理不是线程安全的。
DCP 目前使用后台线程进行元数据集合通信和上传至存储。尽管这些耗时步骤是异步执行的,但它会导致与训练器线程争用 GIL。这不仅严重影响了 GPU 利用率 (QPS),还增加了端到端的上传延迟。对于大规模检查点,CPU 并行处理的开销对净 GPU 训练速度有抑制作用,因为 CPU 同时还需要通过启动 GPU 内核来驱动训练过程。
请参考我们实验中的下图:
图 2:可以看到,即使在暂存(即对训练器的阻塞操作)完成后,训练 QPS 仍出现持续下降。
图 2 中的第一个凹陷(由紫线标记)表明暂存已完成,训练可以继续。然而,第二个明显的下降(紫线和黄线之间的区域)是由于训练器线程和检查点线程争用 Python GIL 造成的,导致在检查点线程执行完成之前,训练 QPS 一直处于受损状态。
集合通信成本
DCP 目前出于多种原因执行多次集合通信:去重、检查点的全局元数据、重新分片以及分布式异常处理。集合通信成本高昂,因为它们需要网络 I/O,并涉及在 GPU 网络间传输的大量元数据进行序列化与反序列化。随着任务规模的扩大,这些集合通信变得极其昂贵,导致端到端延迟显著增加,并可能触发集合通信超时。
解决方案
基于进程的异步检查点
DCP 现在支持通过后台进程进行异步检查点保存。这消除了与训练器线程之间的 Python GIL 争用,从而避免了训练 QPS 的下降。请对比图 2(线程方式)和图 3(后台进程方式)。
保存计划缓存
DCP 在规划阶段和存储 I/O 阶段之间有明确的界限。DCP 中的 SavePlanner 是一个有状态组件,充当 state_dict 的访问代理。规划器管理由各个分片 (Rank) 准备的保存计划,这些计划携带执行写入 I/O 所需的元数据信息。规划步骤包含一次集合通信操作,用于在协调器分片 (coordinator rank) 上收集检查点的全局视图。协调器分片负责对参数/权重进行去重以消除冗余,验证全局计划以确保准确性和一致性,并创建全局元数据结构。随后进行一次散播 (scatter) 集合通信,由协调器分片为每个分片分配 I/O 任务。对计划所做的任何转换都会影响存储组件最终写入数据的方式。
在训练任务过程中,会保存多个检查点。在大多数情况下,两次保存之间只有检查点数据发生变化,而计划保持不变。这为我们提供了缓存计划的机会:仅在第一次保存时支付规划成本,然后将该成本分摊到后续所有尝试中。只有更新后的计划(即在下一次尝试中发生了变化的计划)才会被发送,从而显著降低了集合通信开销。
实验结果
设置: 1856 个 H200 GPU,Llama3-70B 模型,HSDP2 和 TorchTitan
在部署了上述两个解决方案后,主要结果如下:
- TPS 下降幅度显著减小,峰值回落至 372 TPS(此前为 315 TPS),且时间窗口大幅缩短(约 67 秒对比约 437 秒)。目前该时间窗口主要归因于 CPU 处理过程的阻塞。
- 由于规划阶段的开销极低,后续的检查点保存尝试速度依然很快。因此,端到端延迟改善了超过 6.5 倍。这将允许我们的合作伙伴增加检查点保存频率,从而减少训练进度的损失(即浪费的训练时间)。
如果你观察图 1 中的第一个下冲峰,这种 GPU 处理时间的下降将训练吞吐量从 700 TPS 降低到 320 TPS,并压制了大约 7 分钟(467 秒)。一旦 CPU 完成处理,训练就会再次全速进行。
以前,这种约 7 分钟的压制会在*每个*检查点重复出现。然而,通过新的基于进程的检查点功能,只有第一个检查点会有完整的压制时间(主要是由于守护进程初始化的开销),后续所有检查点都通过后台进程执行,从而减轻了与训练器线程的 GIL 争用。
这一点在所有后续检查点中得到了直观展示:平均 MFU 压制时间缩短至仅一分钟多一点,表现为迅速反弹回全 MFU 吞吐量的陡峭尖峰。
图 3:红框显示了非缓存计划的检查点,其中还包括检查点后台初始化进程的开销;紫色框突出了第一个使用缓存计划运行的检查点。
这意味着即使是如图 2 所示的 1856 个 GPU 规模的大规模检查点,其对训练吞吐量的影响也可以降低约 6 倍。这使得异步 DCP 检查点能够更频繁地运行(从而提供更好的回滚保护),同时相比之前的异步检查点开销,整体训练吞吐量也得到了增强。
如何使用 DCP 的缓存检查点
此功能已包含在 PyTorch 每日构建版 (Nightly builds) 中,你可以直接在 TorchTitan 中测试 PyTorch 的异步 DCP 检查点。启用这些功能的说明如下:
- 基于进程的异步检查点
- 在 async_save API 中将 async_checkpointer_type 设置为 AsyncCheckpointerType.PROCESS。(文件:pytorch/torch/distributed/checkpoint/state_dict_saver.py)
- 保存计划缓存
- 在 DefaultSavePlanner 中将 enable_plan_caching 标志设置为 true。(文件:pytorch/torch/distributed/checkpoint/default_planner.py)
未来工作
DCP 将推出更多优化措施以进一步降低检查点成本。目前,尽管保存计划已被缓存,但协调器分片仍需准备元数据。对于大型作业和拥有许多张量的模型,这种开销不容忽视。在下一次迭代中,DCP 将消除元数据开销并进一步优化端到端延迟。DCP 还将引入更多优化,例如零开销检查点,以实现大规模作业中的高效检查点保存。
敬请期待!