
去年,IBM 研究院开始与我们合作,为其大型基础模型引入全分片数据并行 (FSDP) 功能。他们对 FSDP 感兴趣,因为它是 PyTorch 原生提供的解决方案,可用于在 IBM 云上扩展其分布式训练工作。
我们很高兴地宣布,与 IBM 合作,我们已大幅提升大型模型的检查点速度(比原始 PyTorch 1.13 保存速度快 72 倍),证明了模型和优化器检查点可扩展到 30B 参数,并支持使用 FSDP + 分布式检查点在 S3 后端进行云优先训练。
什么是分布式检查点?
分布式检查点是 PyTorch 原生解决方案,用于从多个 rank 保存和加载 PyTorch 模型和优化器状态,并支持在重新加载之间动态更改世界大小。

PyTorch 分布式检查点 (DCP) API 在 PyTorch 1.13 中引入,并作为官方原型功能包含在 PyTorch 2.0 中。
分布式检查点与 torch.save() 和 torch.load() 在几个重要方面有所不同
- DCP 为每个检查点生成多个文件,每个 rank 至少有一个文件,
- DCP 在原地操作,这意味着模型应首先分配其数据,然后分布式检查点将使用该存储。
从 1.13 到 2.0 的一个重大改进包括添加了对 FSDP 模型检查点的 sharded_state_dict 支持。这允许对更大尺寸的模型进行检查点,并增加了对加载时重新分片的支持。加载时重新分片支持在一个集群拓扑中保存,并加载到另一个集群中。此功能被高度请求,因为它允许训练作业在一个集群上运行,保存,然后在一个具有不同世界大小的不同集群上继续。
另一个重大变化是我们解耦了存储层和检查点规划层,并将两个层的实现与接口分开。通过此更改,用户现在可以在检查点规划阶段指定他们的 state_dict 应该如何分块或转换。此外,可定制的存储层可以轻松适应不同的后端。
有关分布式检查点包的更多信息,请参阅此处。
与 IBM 合作,在生产环境中实现高性能分布式检查点
IBM 在 2023 年 Think 大会上宣布了其 watsonx.ai 平台,用于企业基础模型的开发和部署。该平台基于混合云构建,支持跨多种模式的用例,如自然语言处理、时间序列、天气、化学、表格数据和网络安全,模型大小从数亿到数百亿参数不等。模型架构包括视觉 transformer、多模态 RoBERTa 风格的特征提取器,以及类似于 T5、GPT 和 Llama 的大规模生成式语言模型。
截至目前,IBM 已经为 T5 风格的架构(最高 11B 参数)和解码器架构(GPT 风格,最高 30B 参数)实现了检查点。
IBM 帮助我们识别出这会从内存和性能角度限制 DCP 的扩展能力。根据他们的建议,我们增强了 FileSystemWriter,使其为每个 rank 生成单个检查点,以减少读写开销。
将此选项作为新的默认设置,DCP 现在在检查点保存期间为每个 rank 创建单个文件,然后在加载时读取参数时将其切片。
通过结合 sharded_state_dict 支持和每个 rank 单个文件写入器,分布式检查点能够将检查点保存时间加速 72 倍以上,相比原始 PyTorch 1.13 保存速度,并实现对超过 15B 模型尺寸的快速检查点,而这些模型之前会直接超时。
“回首往事,我们所看到的速度提升真是令人惊叹,处理许多这些模型的训练。我们从 PyTorch 1.13 中写入一个 11B 检查点需要近半小时,到现在能够处理一个 30B 参数模型,包括优化器和数据加载器状态——这意味着原始数据量增加了八倍以上——只需 3 分钟多一点。这极大地提高了我们作业的稳定性和效率,因为我们将训练扩展到了数百个 GPU。”—— Davis Wertheimer,IBM 研究院
IBM 的采用也帮助我们在真实世界的大规模训练环境中验证和改进了我们的解决方案。例如,IBM 发现 DCP 在单个节点、多个 GPU 的情况下运行良好,但在多个节点上使用时却出错。
在调查该问题后,我们意识到我们假设写入到类似 NFS 的共享文件系统,这假设了强烈的读写一致性。而像 S3FS 这样的具有文件系统 API 的对象存储提供了最终一致性语义,因此导致分布式检查点在此类设置中失败。通过与 IBM 合作,我们识别了此问题并通过 一行代码更改 修复了它,并为 DCP 启用了对象存储后端!此类存储方法通常比共享文件系统便宜一个数量级,因此可以实现更细粒度的检查点。
寻求合作
如果您有兴趣尝试分布式检查点,请随时与我们联系!
如果您在尝试时遇到任何问题,可以在我们的 Github 仓库中提出问题。
致谢
如果没有许多协作者的帮助,这个项目是不可能完成的。我们要感谢 Yanli Zhao、Andrew Gu、Rohan Varma 对 FSDP 的支持。感谢 Pritam Damania、Junjie Zhao 和 Wanchao Liang 对 ShardedTensor 的支持。