• 教程 >
  • 分布式检查点 (DCP) 入门
快捷方式

分布式检查点 (DCP) 入门

创建于:2023 年 10 月 02 日 | 最后更新:2024 年 10 月 30 日 | 最后验证:2024 年 11 月 05 日

作者: Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin

注意

editgithub 中查看和编辑本教程。

先决条件

在分布式训练期间检查点 AI 模型可能具有挑战性,因为参数和梯度在训练器之间进行分区,并且当您恢复训练时,可用的训练器数量可能会发生变化。Pytorch 分布式检查点 (DCP) 可以帮助简化此过程。

在本教程中,我们将展示如何将 DCP API 与简单的 FSDP 包装模型一起使用。

DCP 的工作原理

torch.distributed.checkpoint() 支持从多个并行 ranks 保存和加载模型。您可以使用此模块在任意数量的 ranks 中并行保存,然后在加载时在不同的集群拓扑中重新分片。

此外,通过使用 torch.distributed.checkpoint.state_dict() 中的模块,DCP 为在分布式设置中优雅地处理 state_dict 生成和加载提供了支持。这包括管理跨模型和优化器的完全限定名称 (FQN) 映射,以及为 PyTorch 提供的并行性设置默认参数。

DCP 与 torch.save()torch.load() 在几个重要方面有所不同

  • 它为每个检查点生成多个文件,每个 rank 至少一个文件。

  • 它就地操作,这意味着模型应首先分配其数据,而 DCP 使用该存储。

  • DCP 为有状态对象(在 torch.distributed.checkpoint.stateful 中正式定义)提供特殊处理,如果定义了 state_dictload_state_dict 方法,则自动调用它们。

注意

本教程中的代码在 8-GPU 服务器上运行,但可以轻松推广到其他环境。

如何使用 DCP

在此,我们使用 FSDP 包装的玩具模型进行演示。同样,API 和逻辑可以应用于更大的模型进行检查点。

保存

现在,让我们创建一个玩具模块,用 FSDP 包装它,用一些虚拟输入数据馈送它,然后保存它。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = { "app": AppState(model, optimizer) }
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

请继续检查 checkpoint 目录。您应该看到 8 个检查点文件,如下所示。

Distributed Checkpoint

加载

保存后,让我们创建相同的 FSDP 包装模型,并将保存的状态字典从存储加载到模型中。您可以在相同的 world size 或不同的 world size 中加载。

请注意,您必须在加载之前调用 model.state_dict() 并将其传递给 DCP 的 load_state_dict() API。这与 torch.load() 根本不同,因为 torch.load() 仅需要在加载之前提供检查点的路径。我们需要 state_dict 在加载之前的原因是

  • DCP 使用来自模型 state_dict 的预分配存储从检查点目录加载。在加载期间,传入的 state_dict 将就地更新。

  • DCP 需要来自模型的分片信息,以便在加载之前支持重新分片。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = { "app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

如果您想将保存的检查点加载到非分布式设置中的非 FSDP 包装模型中(可能是为了推理),您也可以使用 DCP 执行此操作。默认情况下,DCP 以单程序多数据 (SPMD) 样式保存和加载分布式 state_dict。但是,如果未初始化进程组,则 DCP 推断意图是以“非分布式”样式保存或加载,这意味着完全在当前进程中。

注意

对多程序多数据的分布式检查点支持仍在开发中。

import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn


CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])

if __name__ == "__main__":
    print(f"Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()

格式

尚未提及的一个缺点是,DCP 保存检查点的格式与使用 torch.save 生成的格式本质上不同。由于当用户希望与习惯 torch.save 格式的用户共享模型时,或者通常只是想为其应用程序添加格式灵活性时,这可能会成为问题。对于这种情况,我们在 torch.distributed.checkpoint.format_utils 中提供了 format_utils 模块。

为用户方便起见,提供了一个命令行实用程序,其格式如下

python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>

在上面的命令中,modetorch_to_dcpdcp_to_torch 之一。

或者,也为可能希望直接转换检查点的用户提供了方法。

import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
dcp_to_torch_save(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")

结论

总之,我们学习了如何使用 DCP 的 save()load() API,以及它们与 torch.save()torch.load() 的不同之处。此外,我们还学习了如何在状态字典生成和加载期间使用 get_state_dict()set_state_dict() 来自动管理特定于并行性的 FQN 和默认值。

有关更多信息,请参阅以下内容


评价本教程

© 版权所有 2024, PyTorch。

使用 Sphinx 构建,主题由 theme 提供,并由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源