跳转到主要内容
博客

使用 PyTorch DCP 减少分布式检查点的存储空间和带宽占用

总结

PyTorch 分布式检查点 (DCP) 是一个多功能且强大的工具,用于在分布式训练环境中管理模型检查点。其模块化设计使开发人员能够根据其特定需求定制组件,使其成为各种用例的理想解决方案。

在这篇博文中,我们将展示如何利用 PyTorch DCP 的模块化特性集成压缩技术,从而将检查点大小减少了 22%。我们还将深入探讨我们定制方案的实现细节,提供实用的见解和指导,帮助您应用类似技术来优化自己的检查点工作流程并提高整体效率。

动机

大型分布式检查点

随着模型复杂性和规模的增加,分布式检查点成为训练过程中至关重要的组成部分。然而,由于检查点文件体积庞大,通常会导致巨大的存储需求和高昂的带宽成本。

压缩

为了应对这一挑战,压缩成为一种自然的解决方案。鉴于检查点主要由二进制数据(张量)组成,我们的目标是在最小化压缩开销的同时实现最佳压缩比。我们选择了 zstd 压缩算法,因为它高效且有效。

DCP

DCP 的模块化设计具有定义明确且易于扩展的组件,使其成为我们理想的检查点解决方案。

详细信息

自定义 StorageWriter

PyTorch DCP 的 StorageWriter 组件负责将检查点数据写入存储。我们通过修改 _FileSystemWriter 来定制该组件,它扩展了基础的 StorageWriter 类。_FileSystemWriter 类现在接受一个额外的参数 _extension,该参数是 StreamTransformExtension 的实例。

def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    # We used a _FileSystemWriterextended as a storage writer component
    storage_writer: Optional[StorageWriter] = None, 
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    no_dist: bool = False,
) -> Metadata:

class _FileSystemWriter(StorageWriter):

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
 # We customized _FileSystemWriterextended to take in an extension
        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
        serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
        *args: Any,
        **kwargs: Any,
    ) -> None:

StreamTransformExtension 是一个抽象类,定义了两个方法:transform_to(),用于对输出流进行操作;以及 transform_from(),用于对输入流进行操作。这些方法使我们能够对流数据执行自定义转换。

class StreamTransformExtension(Extension):

    @abc.abstractmethod
    def transform_to(self, output: IO[bytes]) -> IO[bytes]:

    @abc.abstractmethod
    def transform_from(self, input: IO[bytes]) -> IO[bytes]:

实现 ZStandard 压缩

我们实现了一个名为 ZStandardStreamTransformExtension 的具体子类,它使用 zstd 压缩算法 提供压缩功能。我们的 ZStandard 类实现了 transform_to() 方法来压缩传出的流数据,以及 transform_from() 方法来解压传入的流数据。

class ZStandard(StreamTransformExtension):

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
# Our compression implementation

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
# Our decompression implementation

结合定制化功能

最后,在保存检查点时,我们将我们自定义的 _FileSystemWriter 类与 ZStandard 压缩扩展相结合。我们编写了一个示例测试来演示所有部分如何协同工作。

fs_writer = FileSystemWriter(
          path=path,
          thread_count=thread_count,
         _extensions=[ZStandard()],
)

save(
         state_dict=state_dict_to_save,
         storage_writer=fs_writer,
)

评估

结果

我们与 IBM 合作,在其内部的一个训练集群上对我们提出的解决方案进行了评估。结果显示,检查点大小显著减少了 22%,但代价是压缩时间增加。然而,通过多线程技术,我们成功地缓解了这一权衡,将检查点时间的增加限制在仅 9%。这证明了我们的解决方案有潜力在减小检查点大小和保持性能之间取得平衡。

模型 每个 rank 的线程数 DCP 检查点大小 (GB) 检查点保存时间 (秒)
基准 ZStd 𝚫 基准 ZStd 𝚫
granite-3b-code-instruct 8 6.72 5.26 -21.8% 1.96 2.15 9.7%
4 6.72 5.26 -21.8% 1.98 2.38 20.2%
1 6.72 5.26 -21.8% 2.34 3.86 64.9%
granite-3.2-8b-instruct 8 15.6 12.08 –22.5% 3.37 3.65 8.3%
4 15.6 12.08 –22.5% 3.72 4.37 17.5%
1 15.6 12.08 –22.5% 5.37 8.45 57.4%

设置

我们选择了 IBM 的两个开源模型(Granite-3B-Code-Instruct-128KGranite-3.2-8B-Instruct)。为了进行评估,我们在 IBM 的 Vela AI 超级计算机上,使用 Alpaca 数据集对这些模型进行全参数 FSDP 微调。Vela 超级计算机位于 IBM 云中,其每个节点都配备了八个 80GB 的 A100 GPU,这些 GPU 通过 NVLink 和 NVSwitch 相互连接。此外,每个节点还有两颗第二代英特尔至强可扩展处理器(Cascade Lake)和 1.5TB 的 DRAM。我们为 Vela 的一个节点配置了以下资源:

测试平台

  • Openshift 4.14 集群
  • Pod:64 个 Intel Cascade Lake CPU 核心,800GB 主机内存,8 块 A100-80GB GPU
  • 以持久卷形式提供的存储选项
    • 1TB 本地 GPFS
    • S3 存储桶

工作负载

  • 使用 FSDP 进行全参数微调,每个 epoch 保存一次检查点

检查点配置

  • 使用 save_state_dict() 保存到存储
  • 每个 rank 1 到 8 个线程
  • 每个 rank 1 个文件
  • 8 个 rank

结论

PyTorch DCP 的模块化设计使开发人员能够根据特定用例定制其组件,从而开启了新的定制和可扩展性水平。通过定制 StorageWriter 组件并实现压缩扩展,我们显著减小了检查点的大小,从而降低了存储需求和带宽成本。

我们邀请您深入阅读我们的文档,并尝试各种扩展和修改,探索 PyTorch DCP 定制化的巨大潜力。欢迎在 PyTorch GitHub 上参与讨论,并与 PyTorch Checkpointing 团队联系(通过带有“oncall: distributed checkpointing”标签的 GitHub issue),分享您的经验、提出问题,并了解最新的开发动态!