博客

使用 PyTorch DCP 减少分布式检查点的存储空间和带宽占用

总结

PyTorch 分布式检查点 (DCP) 是一个功能强大且用途广泛的工具,用于在分布式训练环境中管理模型检查点。其模块化设计使开发人员能够根据特定需求定制其组件,使其成为广泛用例的理想解决方案。

在本博文中,我们将展示我们如何利用 PyTorch DCP 的模块化来集成压缩并实现检查点大小减少 22%。我们还将深入探讨我们定制的实现细节,提供实用见解和指南,说明您如何应用类似技术来优化自己的检查点工作流程并提高整体效率。

动机

大型分布式检查点

随着模型复杂性和规模的增加,分布式检查点成为训练过程的关键组成部分。然而,由于检查点尺寸庞大,它们通常会导致大量的存储需求和高昂的带宽成本。

压缩

为了应对这一挑战,压缩成为一种自然的解决方案。考虑到检查点主要由二进制数据(张量)组成,我们的目标是在压缩开销最小的情况下实现最佳压缩比。我们选择 zstd 压缩算法,因为它具有效率和效果。

DCP

DCP 的模块化设计,拥有定义明确且易于扩展的组件,使其成为我们检查点解决方案的理想选择。

详情

定制 StorageWriter

PyTorch DCP 的 StorageWriter 组件负责将检查点数据写入存储。我们通过修改 _FileSystemWriter 来定制此组件,该组件继承自基础 StorageWriter 类。现在, _FileSystemWriter 类接受一个额外的参数 _extension,它是一个 StreamTransformExtension 实例。

def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    # We used a _FileSystemWriterextended as a storage writer component
    storage_writer: Optional[StorageWriter] = None, 
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    no_dist: bool = False,
) -> Metadata:

class _FileSystemWriter(StorageWriter):

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
 # We customized _FileSystemWriterextended to take in an extension
        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
        serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
        *args: Any,
        **kwargs: Any,
    ) -> None:

StreamTransformExtension 是一个抽象类,它定义了两个方法: transform_to(),在输出流上调用;以及 transform_from(),在输入流上调用。这使我们能够对流数据执行自定义转换。

class StreamTransformExtension(Extension):

    @abc.abstractmethod
    def transform_to(self, output: IO[bytes]) -> IO[bytes]:

    @abc.abstractmethod
    def transform_from(self, input: IO[bytes]) -> IO[bytes]:

实现 ZStandard 压缩

我们实现了一个名为 ZStandardStreamTransformExtension 的具体子类,它使用 zstd 压缩算法 提供压缩功能。我们的 ZStandard 类实现了 transform_to() 来压缩传出的流数据,并实现了 transform_from() 来解压缩传入的流数据。

class ZStandard(StreamTransformExtension):

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
# Our compression implementation

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
# Our decompression implementation

组合自定义

最后,我们在保存检查点时,将我们的自定义 _FileSystemWriter 类与 ZStandard 压缩扩展结合起来。我们编写了一个 示例测试 来演示所有组件如何协同工作。

fs_writer = FileSystemWriter(
          path=path,
          thread_count=thread_count,
         _extensions=[ZStandard()],
)

save(
         state_dict=state_dict_to_save,
         storage_writer=fs_writer,
)

评估

结果

我们与 IBM 合作,在其内部训练集群上对其提出的解决方案进行了评估。结果显示检查点大小显著减少了 22%,尽管压缩时间有所增加。然而,通过多线程,我们能够缓解这一权衡,并将检查点时间增加限制在仅 9%。这证明了我们的解决方案在平衡检查点大小减少和性能方面的潜力。

模型 每个 Rank 的线程数 DCP 检查点大小(以 GB 为单位) 检查点时间(秒)
基准 ZStd 𝚫 基准 ZStd 𝚫
granite-3b-code-instruct 8 6.72 5.26 -21.8% 1.96 2.15 9.7%
4 6.72 5.26 -21.8% 1.98 2.38 20.2%
1 6.72 5.26 -21.8% 2.34 3.86 64.9%
granite-3.2-8b-instruct 8 15.6 12.08 –22.5% 3.37 3.65 8.3%
4 15.6 12.08 –22.5% 3.72 4.37 17.5%
1 15.6 12.08 –22.5% 5.37 8.45 57.4%

设置

我们选择了 IBM 的两个开源模型(Granite-3B-Code-Instruct-128KGranite-3.2-8B-Instruct)。在评估中,我们在 IBM 的 Vela AI 超级计算机(位于 IBM 云上)上,使用 Alpaca 数据集 对这些模型进行全参数 FSDP 微调。Vela 的每个节点都有八个 80GB A100 GPU,它们通过 NVLink 和 NVSwitch 连接。此外,每个节点有两个第二代 Intel Xeon 可扩展处理器(Cascade Lake)和 1.5TB DRAM。我们使用以下资源配置了一个 Vela 节点:

测试平台

  • Openshift 4.14 集群
  • Pod:64 个 Intel Cascade Lake CPU 核心,800GB 主机内存,8 个 A100-80GB GPU
  • 作为持久卷公开的存储选项
    • 1TB 本地 GPFS
    • S3 存储桶

工作负载

  • 全参数 FSDP 微调,每个 epoch 进行一次检查点

检查点配置

  • 将 state_dict() 保存到存储
  • 每个 Rank 1 到 8 个线程
  • 每个 Rank 1 个文件
  • 8 个 Rank

结论

PyTorch DCP 的模块化设计使开发人员能够根据特定用例定制其组件,从而实现新的定制和可扩展性水平。通过定制 StorageWriter 组件并实现压缩扩展,我们实现了显著的检查点大小缩减,从而降低了存储需求并减少了带宽成本。

我们邀请您通过深入研究我们的文档并尝试各种扩展和修改,来探索 PyTorch DCP 定制的广阔可能性。加入 PyTorch GitHub 上的讨论,并与 PyTorch Checkpointing 团队联系(带有标签“oncall: distributed checkpointing”的开放 GitHub issue),以分享您的经验,提出问题,并随时了解最新动态!