作者:Meta:Iris Zhang, Less Wright, Rodrigo Kumpera, Chien-Chin Huang, IBM:Davis Wertheimer, Supriyo Chakraboty, Sophia Wen, Raghu Ganti, Mudhakar Srivatsa, Seethrami Seelam

Params saved per minute

去年,IBM 研究院开始与我们合作,为其大型基础模型引入 Fully Sharded Data Parallelism (FSDP)。他们对此很感兴趣,因为 FSDP 是 PyTorch 原生的产品,可用于在 IBM Cloud 上扩展其分布式训练工作。

我们很高兴地宣布,通过与 IBM 合作,我们已为大型模型实现了显著的检查点速度提升(与 PyTorch 1.13 原始保存速度相比提升 72 倍),证明了模型和优化器检查点可扩展到 300 亿参数,并支持在 S3 后端上使用 FSDP + 分布式检查点进行云优先训练。

什么是分布式检查点?

分布式检查点是 PyTorch 原生的解决方案,用于从多个 Rank 保存和加载 PyTorch 模型和优化器状态,并支持在重新加载时动态更改 world size。

Checkpoint time vs model params

PyTorch 分布式检查点 (DCP) API 在 PyTorch 1.13 中引入,并作为官方原型特性包含在 PyTorch 2.0 中。

分布式检查点与 torch.save() 和 torch.load() 在几个重要方面有所不同

  1. DCP 为每个检查点生成多个文件,每个 Rank 至少一个文件,
  2. DCP 就地操作,这意味着模型应首先分配其数据,然后分布式检查点将使用该存储空间。

从 1.13 到 2.0 的一项重大改进包括为 FSDP 模型检查点添加了 sharded_state_dict 支持。这允许对更大规模的模型进行检查点,并添加了加载时重新分片的支持。加载时重新分片可以在一种集群拓扑中保存,然后加载到另一种拓扑中。此特性需求量很大,因为它允许训练作业在一个集群上运行、保存,然后在具有不同 world size 的另一个集群上继续。

另一个重大变化是我们解耦了存储层和检查点规划层,并分离了这两层的实现和接口。通过此更改,用户现在可以在检查点规划阶段指定其 state_dict 如何进行分块或转换。此外,可定制的存储层可以轻松适应不同的后端。

有关分布式检查点包的更多信息,请点击此处查看。

在生产环境中与 IBM 合作实现高性能分布式检查点

IBM 在 Think 2023 大会上发布了用于开发和部署企业基础模型的 watsonx.ai 平台。该平台构建在混合云之上,支持跨多种模态(如 NLP、时间序列、天气、化学、表格数据和网络安全)的使用场景,模型规模从数亿到数百亿参数不等。模型架构包括视觉 Transformer、多模态 RoBERTa 风格特征提取器,以及类似于 T5、GPT 和 Llama 的大规模生成式语言模型。

截至今天,IBM 已为参数高达 110 亿的 T5 风格架构和参数高达 300 亿的解码器架构(GPT 风格)启用了检查点功能。

IBM 帮助我们发现这从内存和性能角度限制了 DCP 的扩展能力。根据他们的建议,我们增强了 FileSystemWriter,使其每个 Rank 生成单个检查点文件,以减少读写开销。

将此选项作为新的默认设置后,DCP 现在在保存检查点时为每个 Rank 创建一个文件,然后在加载时读取参数时进行切片。

通过将 sharded_state_dict 支持与每个 Rank 单文件写入器结合,分布式检查点保存时间比 PyTorch 1.13 原始保存速度提高了 72 倍以上,并能够为参数量超过 150 亿的模型实现快速检查点,这些模型之前会简单地超时。

“回想起来,我们在处理许多此类模型的训练时所看到的速度提升确实令人惊叹。在 PyTorch 1.13 中,写一个 110 亿参数的检查点需要将近半小时,而现在我们只需 3 分钟多一点的时间就能处理一个包含优化器和数据加载器状态的 300 亿参数模型——原始数据量是以前的八倍多。随着我们将训练扩展到数百个 GPU,这极大地提升了我们作业的稳定性和效率。”——Davis Wertheimer,IBM 研究院

IBM 的采用也帮助我们在真实的、大规模的训练环境中验证和改进了我们的解决方案。例如,IBM 发现 DCP 在具有多个 GPU 的单节点上运行良好,但在多节点上使用时会出错。

在调查问题后,我们意识到我们假设写入类似 NFS 的共享文件系统,该系统假定写后读强一致性。而提供文件系统 API 的对象存储(如 S3FS)提供最终一致性语义,从而导致在这种设置下的分布式检查点失败。通过与 IBM 合作,我们发现了这个问题,并通过修改一行代码修复了它,并为 DCP 启用了对象存储后端!此类存储方法的成本通常比共享文件系统低一个数量级,从而能够实现更精细的检查点。

寻求合作

如果您有兴趣尝试分布式检查点,请随时与我们联系!

如果您在尝试时遇到任何问题,可以在我们的 Github 仓库中提交Issue

致谢

如果没有许多合作者的帮助,这个项目是不可能完成的。我们要感谢 Yanli Zhao、Andrew Gu、Rohan Varma 对 FSDP 的支持。感谢 Pritam Damania、Junjie Zhao 和 Wanchao Liang 对 ShardedTensor 的支持。