摘要 

PyTorch 分布式检查点 (DCP) 正在致力于解决互操作性障碍,以确保像 HuggingFace safetensors 这样流行的格式能够与 PyTorch 生态系统良好兼容。由于 HuggingFace 已成为推理和微调领域的主流格式,DCP 开始提供对 HuggingFace safetensors 的支持。该功能的第一个用户是 torchtune,他们现在可以通过 DCP API 直接读写 HuggingFace 格式,从而获得了更好的用户体验。

问题

由于 HuggingFace 使用广泛,拥有超过 500 万用户,许多机器学习工程师希望以 safetensors 格式保存和加载检查点,以便轻松地与其生态系统协作。通过在 DCP 中原生支持 safetensors 格式,我们简化了用户的检查点管理流程:

  • DCP 目前拥有自己的自定义格式,因此那些既想与 HuggingFace 模型协同工作,又想利用 DCP 性能优势 和功能的用户,此前必须构建自定义转换器和组件,以便在两个系统之间进行转换。
  • 无需用户每次都将检查点下载到本地存储再上传,现在 HuggingFace 模型可以直接保存到用户选择的、支持 fsspec 的存储中,并从中加载。

如何使用

从用户角度来看,使用 safetensors 唯一需要做的改变就是调用 load 时使用新的 加载规划器 (load planner)存储读取器 (storage reader),同样地,保存时使用新的 保存规划器 (save planner)存储写入器 (storage writer)

加载和保存 API 的调用方式如下


load(
	state_dict=state_dict,
	storage_reader=HuggingFaceStorageReader(path=path),
)

save(
	state_dict=state_dict,
	storage_writer=HuggingFaceStorageWriter(
				path=path,
				fqn_to_index_mapping=mapping
			),
)

HuggingFaceStorageReader 和 HuggingFaceStorageWriter 可以接收任何基于 fsspec 的路径,因此可以将 HF safetensors 格式读写到任何支持 fsspec 的后端,包括本地存储和 HF 存储。由于 HuggingFace safetensors 元数据原生提供的信息级别与 DCP 元数据不同,目前这些 API 对分布式检查点的支持尚不完善,但 DCP 计划在未来原生支持此功能。

 

torchtune

我们 HuggingFace DCP 支持的首个客户是 torchtune —— 一个用原生 PyTorch 编写的后训练库。torchtune 用户检索模型权重的主要方式是从 Hugging Face Hub。以前,用户必须下载模型权重,并通过额外的 CLI 命令上传训练好的检查点;而 新的 DCP API 允许他们直接对 HuggingFace 进行读写,从而显著提升了用户体验。 

此外,DCP 中对 safetensor 序列化的支持极大地简化了 torchtune 中的检查点代码。不再需要针对特定格式的检查点解决方案,从而提高了项目的开发效率。

未来工作

DCP 计划通过重新分片 (resharding) 来处理 HuggingFace safetensors 检查点的分布式加载和保存。DCP 还计划支持将最终检查点合并为单个文件以供发布的功能。