摘要
PyTorch 分布式检查点 (DCP) 正在投入解决互操作性障碍,以确保流行的格式(如 HuggingFace safetensors)能够与 PyTorch 生态系统良好协作。由于 HuggingFace 已成为推理和微调领域的主导格式,DCP 开始支持 HuggingFace safetensors。这些更改的第一个客户是 torchtune,他们看到了改进的用户体验,因为他们现在可以直接使用 DCP API 清晰地读写 HuggingFace。
问题
由于 HuggingFace 广泛使用,拥有超过 500 万用户,许多机器学习工程师希望以 safetensors 格式保存和加载他们的检查点,以便轻松地与其生态系统协同工作。通过在 DCP 中原生支持 safetensors 格式,检查点对我们的用户来说通过以下方式得到简化:
- DCP 目前有自己的自定义格式,因此希望使用 HuggingFace 模型但利用 DCP 性能优势和功能的用户必须构建自定义转换器和组件,以便他们可以在两个系统之间工作。
- 用户不再需要每次都将检查点下载和上传到本地存储,现在可以直接将 HuggingFace 模型保存和加载到他们选择的 fsspec 支持的存储中。
如何使用
从用户的角度来看,使用 safetensors 所需的唯一更改是使用新的 加载规划器和 存储读取器调用加载,类似地使用新的 保存规划器和 存储写入器调用保存。
加载和保存 API 的调用方式如下
load(
state_dict=state_dict,
storage_reader=HuggingFaceStorageReader(path=path),
)
save(
state_dict=state_dict,
storage_writer=HuggingFaceStorageWriter(
path=path,
fqn_to_index_mapping=mapping
),
)
HuggingFaceStorageReader 和 HuggingFaceStorageWriter 可以接受任何基于 fsspec 的路径,因此它可以以 HF safetensors 格式读写到任何 fsspec 支持的后端,包括本地存储和 HF 存储。由于 HuggingFace safetensors 元数据不原生提供与 DCP 元数据相同级别的信息,因此这些 API 目前不支持分布式检查点,但 DCP 计划在未来原生支持此功能。
torchtune
我们 HuggingFace DCP 支持的第一个客户是 torchtune——一个用原生 PyTorch 编写的训练后库。torchtune 用户获取模型权重的主要方式是来自 Hugging Face Hub。以前,用户必须通过额外的 CLI 命令下载模型权重并上传训练好的检查点;新的 DCP API 允许他们直接读写 HuggingFace,从而带来更好的用户体验。
此外,DCP 中对 safetensor 序列化的支持大大简化了 torchtune 中的检查点代码。将不再需要特定于格式的检查点解决方案,从而提高了项目中的开发效率。
未来工作
DCP 计划处理 HuggingFace safetensors 检查点的分布式加载和保存以及重新分片。DCP 还计划支持将合并的最终检查点生成为单个文件以供发布的能力。