
去年,IBM 研究院开始与我们合作,为他们的大型基础模型引入完全分片数据并行 (FSDP)。他们对此产生了兴趣,因为 FSDP 是 PyTorch 原生提供的解决方案,用于在 IBM Cloud 上扩展他们的分布式训练工作。
我们很高兴地宣布,通过与 IBM 的合作,我们已经为大型模型实现了显著的检查点加速(与原始 PyTorch 1.13 的保存速度相比,提升了 72 倍),证明了模型和优化器检查点可以扩展到 30 亿参数,并支持使用 FSDP + 分布式检查点在 S3 后端进行云优先训练。
什么是分布式检查点?
分布式检查点是 PyTorch 原生解决方案,用于从多个 rank 保存和加载 PyTorch 模型和优化器状态,并支持在重新加载之间动态更改世界大小。

PyTorch 分布式检查点 (DCP) API 在 PyTorch 1.13 中引入,并作为官方原型功能包含在 PyTorch 2.0 中。
分布式检查点与 torch.save() 和 torch.load() 在几个重要方面有所不同:
- DCP 每个检查点生成多个文件,每个 rank 至少一个文件。
- DCP 原地操作,这意味着模型应首先分配其数据,然后分布式检查点将使用该存储空间。
从 1.13 到 2.0 的一个主要改进是增加了对分片 state_dict 的支持,用于检查点 FSDP 模型。这允许对更大尺寸的模型进行检查点,并增加了对加载时重新分片的支持。加载时重新分片允许在一个集群拓扑中保存,然后加载到另一个集群拓扑中。此功能需求很高,因为它允许在某一个集群上运行训练作业,保存后,然后可以在具有不同世界大小的不同集群上继续。
另一个主要变化是,我们将存储层与检查点规划层解耦,并为这两层分离了实现和接口。通过此更改,用户现在可以在检查点规划阶段指定其 state_dict 应如何分块或转换。此外,可定制的存储层可以轻松适应不同的后端。
有关分布式检查点包的更多信息可以在此处找到。
与 IBM 合作,在生产环境中实现高性能分布式检查点
IBM 在 Think 2023 上宣布了其针对企业基础模型开发和部署的 watsonx.ai 平台。该平台建立在混合云之上,支持跨多种模式的用例,如自然语言处理、时间序列、天气、化学、表格数据和网络安全,模型大小从数亿到数百亿参数不等。模型架构包括视觉 Transformer、多模态 RoBERTa 风格的特征提取器,以及类似于 T5、GPT 和 Llama 的大规模生成式语言模型。
截至今天,IBM 已为 T5 风格的架构(最高 110 亿参数)和解码器架构(GPT 风格,最高 300 亿参数)启用了检查点。
IBM 帮助我们认识到,这限制了 DCP 从内存和性能角度的扩展能力。根据他们的建议,我们增强了 FileSystemWriter,使其每个 rank 生成单个检查点,以减少读写开销。
有了这个新默认选项,DCP 现在在检查点保存期间为每个 rank 创建一个文件,然后在加载时读取参数时进行切片。
通过将 sharded_state_dict 支持与每个 rank 的单个文件写入器相结合,分布式检查点能够将检查点保存时间加速 72 倍以上(与原始 PyTorch 1.13 的保存速度相比),并使超过 150 亿参数的模型能够进行快速检查点,而这些模型以前会简单地超时。
“回想起来,我们所看到的这些模型的训练速度提升着实令人震惊。我们从 PyTorch 1.13 中编写一个 110 亿参数的检查点需要将近半小时,到现在能够处理一个 300 亿参数的模型,包括优化器和数据加载器状态——这相当于超过八倍的原始数据——仅需 3 分钟多一点。这极大地提高了我们作业的稳定性和效率,因为我们将训练扩展到数百个 GPU。”—— Davis Wertheimer,IBM 研究院
IBM 的采用也帮助我们在真实的、大规模的训练环境中验证和改进了我们的解决方案。例如,IBM 发现 DCP 在单个节点上使用多个 GPU 时运行良好,但在多个节点上使用时却出错。
在调查此问题时,我们意识到我们假设写入 NFS 类的共享文件系统,这假定强读写一致性。带有文件系统 API 的对象存储(例如 S3FS)提供最终一致性语义,因此导致在此类设置中的分布式检查点失败。我们与 IBM 合作,发现了这个问题并通过 一行代码更改 修复了它,并为 DCP 启用了对象存储后端!这种存储方法通常比共享文件系统便宜一个数量级,因此可以实现更细粒度的检查点。
寻求合作
如果您有兴趣尝试分布式检查点,请随时与我们联系!
如果您在尝试时遇到任何问题,可以在我们的 Github 仓库中提出问题。
致谢
如果没有许多合作者的帮助,这个项目是不可能实现的。我们要感谢 Yanli Zhao、Andrew Gu、Rohan Varma 对 FSDP 的支持。感谢 Pritam Damania、Junjie Zhao 和 Wanchao Liang 对 ShardedTensor 的支持。