作者:Meta:Iris Zhang, Less Wright, Rodrigo Kumpera, Chien-Chin Huang,IBM:Davis Wertheimer, Supriyo Chakraboty, Sophia Wen, Raghu Ganti, Mudhakar Srivatsa, Seethrami Seelam

Params saved per minute

去年,IBM Research 开始与我们合作,为其大型基础模型引入完全分片数据并行 (FSDP)。 他们对此很感兴趣,因为 FSDP 是 PyTorch 原生产品,可用于扩展他们在 IBM Cloud 上的分布式训练工作。

我们很高兴地分享,通过与 IBM 的合作,我们已大幅提升大型模型的检查点速度(相比原始 PyTorch 1.13 保存速度提升 72 倍),验证了模型和优化器检查点可扩展至 300 亿参数,并实现了在 S3 后端使用 FSDP + 分布式检查点的云优先训练。

什么是分布式检查点?

分布式检查点是 PyTorch 原生解决方案,用于从多个 rank 保存和加载 PyTorch 模型和优化器状态,并支持在重新加载之间动态更改 world size。

Checkpoint time vs model params

PyTorch 分布式检查点 (DCP) API 在 PyTorch 1.13 中引入,并在 PyTorch 2.0 中作为官方原型功能包含。

分布式检查点在几个重要方面与 torch.save() 和 torch.load() 不同

  1. DCP 每个检查点生成多个文件,每个 rank 至少一个文件,
  2. DCP 就地操作,这意味着模型应首先分配其数据,然后分布式检查点将使用该存储。

从 1.13 到 2.0 的一项重大改进包括为 FSDP 模型检查点添加了 sharded_state_dict 支持。 这允许对更大尺寸的模型进行检查点操作,并增加了对加载时重新分片的支持。 加载时重新分片允许在一个集群拓扑中保存,并在另一个集群拓扑中加载。 此功能是用户强烈要求的,因为它允许训练作业在一个集群上运行、保存,然后在具有不同 world size 的不同集群上继续运行。

另一个重大变化是我们将存储层与检查点规划层解耦,并将两层的实现与接口分离。 通过此更改,用户现在可以指定在检查点规划阶段应如何对 state_dict 进行分块或转换。 此外,可自定义的存储层可以轻松适应不同的后端。

有关分布式检查点包的更多信息,请点击 此处

在生产环境中使用 IBM 实现高性能分布式检查点

IBM 在 2023 年 Think 大会上宣布了其 watsonx.ai 平台,用于开发和部署面向企业的基础模型。 该平台基于混合云构建,支持跨多种模态的用例,例如 NLP、时间序列、天气、化学、表格数据和网络安全,模型大小从数百万到数百亿参数不等。 模型架构范围从视觉 Transformer,到多模态 RoBERTa 风格的特征提取器,到类似于 T5、GPT 和 Llama 的大型生成语言模型。

截至目前,IBM 已为高达 110 亿参数的 T5 风格架构和高达 300 亿参数的解码器架构(GPT 风格)启用了检查点功能。

IBM 帮助我们确定这从内存和性能角度限制了 DCP 的扩展能力。 根据他们的建议,我们增强了 FileSystemWriter,使其为每个 rank 生成单个检查点,以减少读写开销。

以此选项作为新的默认设置,DCP 现在在检查点保存期间为每个 rank 创建一个单独的文件,然后在加载时读取参数时对其进行切片。

通过将 sharded_state_dict 支持与每个 rank 单个文件的写入器相结合,分布式检查点能够将检查点保存时间加速 72 倍以上(相比原始 PyTorch 1.13 保存速度),并为以前会超时的 150 亿以上模型尺寸的模型实现快速检查点。

“回顾过去,我们看到的加速效果确实令人震惊,这加速了我们对许多模型的训练处理。 我们从在 PyTorch 1.13 中花费近半小时写入单个 110 亿检查点,到能够在 3 分多钟内处理具有优化器和数据加载器状态的 300 亿参数模型(原始数据的八倍以上)。 这为我们工作的稳定性和效率带来了奇迹,因为我们将训练扩展到数百个 GPU。” – Davis Wertheimer, IBM Research

IBM 的采用还帮助我们在真实的、大规模训练环境中验证和改进我们的解决方案。 例如,IBM 发现 DCP 在具有多个 GPU 的单个节点上运行良好,但在多个节点上使用时会出错。

在调查该问题后,我们意识到我们假设写入类似 NFS 的共享文件系统,该系统假定具有强大的读后写一致性。 具有文件系统 API(如 S3FS)的对象存储提供最终一致性语义,因此导致在这种设置下分布式检查点失败。 通过与 IBM 合作,我们发现了这个问题,并通过进行 一行代码更改 修复了它,并为 DCP 启用了对象存储后端! 这种存储方法通常比共享文件系统便宜一个数量级,因此可以实现更细粒度的检查点。

寻求合作

如果您有兴趣尝试分布式检查点,请随时与我们联系!

如果您在尝试时遇到任何问题,可以在我们的 Github 代码库 中提交 issue。

致谢

这个项目如果没有众多合作者的帮助是不可能实现的。 我们要感谢 Yanli Zhao、Andrew Gu、Rohan Varma 对 FSDP 的支持。 感谢 Pritam Damania、Junjie Zhao 和 Wanchao Liang 对 ShardedTensor 的支持。