
Getting Started with Distributed Checkpoint (DCP)

Created On: Oct 02, 2023 | Last Updated: Oct 30, 2024 | Last Verified: Nov 05, 2024

Authors: Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin


Checkpointing AI models during distributed training can be challenging, as parameters and gradients are partitioned across trainers and the number of trainers available can change when you resume training. PyTorch Distributed Checkpoint (DCP) can help make this process easier.

In this tutorial, we show how to use the DCP APIs with a simple FSDP-wrapped model.

How DCP Works

torch.distributed.checkpoint() enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time.

Additionally, through the use of modules in torch.distributed.checkpoint.state_dict(), DCP offers support for gracefully handling state_dict generation and loading in distributed settings. This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.
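
As a quick preview of those helpers (the full examples below use the same calls), fetching and restoring parallelism-aware state dicts is a pair of short calls. This is a minimal sketch, assuming that model and optimizer already exist:

from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

# Returns FQN-normalized state dicts for the model and optimizer, applying
# parallelism-appropriate defaults (for example, sharded state dicts under FSDP).
model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)

# The inverse: applies loaded state dicts back onto the model and optimizer.
set_state_dict(
    model,
    optimizer,
    model_state_dict=model_state_dict,
    optim_state_dict=optimizer_state_dict,
)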

DCP is different from torch.save() and torch.load() in a few significant ways:

  • It produces multiple files per checkpoint, with at least one per rank.

  • It operates in place, meaning that the model should allocate its data first, and DCP then uses that storage.

  • DCP offers special handling of stateful objects (formally defined in torch.distributed.checkpoint.stateful), automatically calling their state_dict and load_state_dict methods if they are defined, as sketched below.
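
For illustration, here is a minimal sketch of a Stateful object; TrainState and its step counter are hypothetical names for this sketch, and the AppState wrapper in the full examples below follows the same pattern:

from torch.distributed.checkpoint.stateful import Stateful

class TrainState(Stateful):  # hypothetical class, for illustration only
    def __init__(self):
        self.step = 0

    def state_dict(self):
        # dcp.save calls this automatically for any Stateful value in the passed state_dict
        return {"step": self.step}

    def load_state_dict(self, state_dict):
        # dcp.load calls this automatically after reading the checkpoint
        self.step = state_dict["step"]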

Note

The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.

How to Use DCP

Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly, the APIs and logic can be applied to larger models for checkpointing.

Saving

Now, let's create a toy module, wrap it with FSDP, feed it some dummy input data, and save it.

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = { "app": AppState(model, optimizer) }
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Go ahead and check the checkpoint directory. You should see 8 checkpoint files, as shown below.

[Image: Distributed Checkpoint files in the checkpoint directory]
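
For reference, with 8 ranks the directory contents look roughly like the listing below. The exact file names are an implementation detail that varies across PyTorch versions, so treat this as an illustration rather than a contract:

checkpoint/
├── .metadata
├── __0_0.distcp
├── __1_0.distcp
├── ...
└── __7_0.distcp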

Loading

After saving, let's create the same FSDP-wrapped model and load the saved state dict from storage into the model. You can load either in the same world size or in a different world size.

Please note that you will have to call model.state_dict() prior to loading and pass it to DCP's load_state_dict() API. This is fundamentally different from torch.load(), which simply requires the path to the checkpoint. The reasons that we need the state_dict prior to loading are:

  • DCP uses the pre-allocated storage from the model state_dict to load from the checkpoint directory. During loading, the state_dict passed in will be updated in place.

  • DCP requires the sharding information from the model prior to loading, in order to support resharding.

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = { "app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. By default, DCP saves and loads a distributed state_dict in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

Note

Distributed checkpoint support for Multi-Program Multi-Data is still under development.

import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn


CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no process group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])

if __name__ == "__main__":
    print(f"Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()

Formats

One drawback not yet mentioned is that DCP saves checkpoints in a format that is inherently different from those generated by torch.save. This can be an issue for users who are used to the torch.save format, or who simply want to add format flexibility to their applications. For this case, we provide the format_utils module in torch.distributed.checkpoint.format_utils.

For the user's convenience, a command line utility is provided, which follows this format:

python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>

In the command above, mode is one of torch_to_dcp or dcp_to_torch.
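
For example, converting the checkpoint saved above into a single torch.save file might look like this (the paths are illustrative):

python -m torch.distributed.checkpoint.format_utils dcp_to_torch checkpoint checkpoint.pth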

Additionally, methods are also provided for users who may wish to convert checkpoints directly:

from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# convert the torch.save checkpoint back to DCP format
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")

Conclusion

In conclusion, we have learned how to use DCP's save() and load() APIs, as well as how they differ from torch.save() and torch.load(). Additionally, we've learned how to use get_state_dict() and set_state_dict() to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.

