跳转到主要内容
博客

使用 PyTorch DCP 减少分布式检查点的存储空间和带宽占用

总结

PyTorch 分布式检查点 (DCP) 是一个多功能且强大的工具,用于在分布式训练环境中管理模型检查点。其模块化设计使开发人员能够根据其特定要求定制其组件,使其成为各种用例的理想解决方案。

在这篇博文中,我们将展示我们如何利用 PyTorch DCP 的模块化来集成压缩,并将检查点大小减少 22%。我们还将深入探讨我们定制的实现细节,提供实用见解和指导,说明您如何应用类似的技术来优化自己的检查点工作流并提高整体效率。

动机

大型分布式检查点

随着模型复杂性和规模的增加,分布式检查点成为训练过程的关键组成部分。然而,由于其庞大的大小,这些检查点通常会导致巨大的存储需求和高昂的带宽成本。

压缩

为了解决这个挑战,压缩自然而然地成为解决方案。鉴于检查点主要由二进制数据(张量)组成,我们旨在以最小的压缩开销获得最佳压缩比。我们选择了 zstd 压缩算法,因为它高效且有效。

DCP

DCP 的模块化设计,具有定义明确且易于扩展的组件,使其成为我们检查点解决方案的理想选择。

详情

定制 StorageWriter

PyTorch DCP 的 StorageWriter 组件负责将检查点数据写入存储。我们通过修改 _FileSystemWriter 来定制此组件,该组件扩展了基本 StorageWriter 类。现在,_FileSystemWriter 类接受一个额外的参数 _extension,它是 StreamTransformExtension 的一个实例。

def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    # We used a _FileSystemWriterextended as a storage writer component
    storage_writer: Optional[StorageWriter] = None, 
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    no_dist: bool = False,
) -> Metadata:

class _FileSystemWriter(StorageWriter):

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
 # We customized _FileSystemWriterextended to take in an extension
        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
        serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
        *args: Any,
        **kwargs: Any,
    ) -> None:

StreamTransformExtension 是一个抽象类,它定义了两个方法:transform_to(),在输出流上调用;和 transform_from(),在输入流上调用。这些方法使我们能够对流数据执行自定义转换。

class StreamTransformExtension(Extension):

    @abc.abstractmethod
    def transform_to(self, output: IO[bytes]) -> IO[bytes]:

    @abc.abstractmethod
    def transform_from(self, input: IO[bytes]) -> IO[bytes]:

实现 ZStandard 压缩

我们实现了 StreamTransformExtension 的具体子类,名为 ZStandard,它使用 zstd 压缩算法提供压缩功能。我们的 ZStandard 类实现了 transform_to() 来压缩传出流数据,并实现了 transform_from() 来解压缩传入流数据。

class ZStandard(StreamTransformExtension):

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
# Our compression implementation

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
# Our decompression implementation

结合定制

最后,在保存检查点时,我们将自定义的 _FileSystemWriter 类与 ZStandard 压缩扩展相结合。我们编写了一个 示例测试 来演示所有这些如何协同工作。

fs_writer = FileSystemWriter(
          path=path,
          thread_count=thread_count,
         _extensions=[ZStandard()],
)

save(
         state_dict=state_dict_to_save,
         storage_writer=fs_writer,
)

评估

结果

我们与 IBM 合作,在他们的一个内部训练集群上评估了我们提出的解决方案。结果显示,检查点大小显著减少了 22%,尽管代价是压缩时间增加。然而,通过多线程,我们能够缓解这种权衡,并将检查点时间增加限制在仅 9%。这表明我们的解决方案有潜力在检查点大小减少和性能之间取得平衡。

模型 每等级线程数 DCP 检查点大小(GB) 检查点时间(秒)
基准 ZStd 𝚫 基准 ZStd 𝚫
granite-3b-code-instruct 8 6.72 5.26 -21.8% 1.96 2.15 9.7%
4 6.72 5.26 -21.8% 1.98 2.38 20.2%
1 6.72 5.26 -21.8% 2.34 3.86 64.9%
granite-3.2-8b-instruct 8 15.6 12.08 –22.5% 3.37 3.65 8.3%
4 15.6 12.08 –22.5% 3.72 4.37 17.5%
1 15.6 12.08 –22.5% 5.37 8.45 57.4%

设置

我们选择了 IBM 的两个开源模型(Granite-3B-Code-Instruct-128KGranite-3.2-8B-Instruct)。为了进行评估,我们使用 Alpaca 数据集IBM 的 Vela AI 超级计算机(位于 IBM 云中)上对这些模型进行全参数 FSDP 微调,并进行检查点。每个 Vela 节点有八个 80GB A100 GPU,它们通过 NVLink 和 NVSwitch 连接。此外,每个节点有两个第二代英特尔至强可扩展处理器(Cascade Lake)和 1.5TB DRAM。我们为 Vela 的一个节点配置了以下资源:

测试平台

  • Openshift 4.14 集群
  • Pod:64 个 Intel Cascade Lake CPU 内核,800GB 主机内存,8 个 A100-80GB GPU
  • 以持久卷形式提供的存储选项
    • 1TB 本地 GPFS
    • S3 存储桶

工作负载

  • 全参数 FSDP 微调,每个 epoch 进行检查点

检查点配置

  • 将 save_state_dict() 保存到存储
  • 每个等级 1 到 8 个线程
  • 每个等级 1 个文件
  • 8 个等级

结论

PyTorch DCP 的模块化设计使开发人员能够根据特定用例定制其组件,从而开启了新的自定义和可扩展性水平。通过定制 StorageWriter 组件并实现压缩扩展,我们显著减少了检查点大小,从而降低了存储要求和带宽成本。

我们邀请您深入了解我们的文档,并通过试验各种扩展和修改来探索 PyTorch DCP 定制的广阔可能性。加入 PyTorch GitHub 上的讨论,并与 PyTorch Checkpointing 团队联系(通过标签“oncall: distributed checkpointing”打开 GitHub 问题)分享您的经验、提出问题并及时了解最新进展!