• 教程 >
  • 使用分布式检查点 (DCP) 进行异步保存
快捷方式

使用分布式检查点 (DCP) 进行异步保存

创建日期:2024 年 7 月 22 日 | 最后更新:2024 年 7 月 22 日 | 最后验证:2024 年 11 月 5 日

作者: Lucas Pasqualin, Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang

检查点保存 (Checkpointing) 通常是分布式训练工作负载关键路径中的瓶颈,随着模型和世界规模的增长,其开销越来越大。缓解这一开销的一个绝佳策略是异步并行地保存检查点。下面,我们扩展了来自《分布式检查点入门教程》的保存示例,以展示如何轻松集成 torch.distributed.checkpoint.async_save

您将学到什么
  • 如何使用 DCP 并行生成检查点

  • 有效的性能优化策略

前提条件

异步检查点概述

在开始使用异步检查点之前,了解它与同步检查点相比的差异和局限性非常重要。具体来说

  • 内存要求 - 异步检查点的工作方式是首先将模型复制到内部 CPU 缓冲区中。

    这很有帮助,因为它确保模型和优化器权重在检查点保存完成前不会改变,但这会增加 CPU 内存,增加量为 checkpoint_size_per_rank X number_of_ranks。此外,用户应注意了解其系统的内存限制。具体来说,Pinned memory (锁定内存) 意味着使用 page-lock 内存,这比 pageable 内存稀缺。

  • 检查点管理 - 由于检查点是异步的,用户需要自行管理同时运行的检查点。一般来说,用户可以

    通过处理 async_save 返回的 Future 对象来采用自己的管理策略。对于大多数用户,我们建议将同时运行的检查点限制为一个异步请求,以避免每个请求带来额外的内存压力。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # waits for checkpointing to finish if one exists, avoiding queuing more then one checkpoint request at a time
        if checkpoint_future is not None:
            checkpoint_future.result()

        state_dict = { "app": AppState(model, optimizer) }
        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running async checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

使用 Pinned Memory 进一步提升性能

如果上述优化仍然不够高效,您可以利用 GPU 模型的额外优化,该优化使用 Pinned Memory (锁定内存) 缓冲区进行检查点暂存。具体来说,此优化解决了异步检查点的主要开销,即内存中的复制到检查点缓冲区。通过在检查点请求之间维护 Pinned Memory 缓冲区,用户可以利用直接内存访问来加快此复制过程。

注意

此优化的主要缺点是缓冲区在检查点步骤之间持续存在。如上所示,如果不使用 Pinned Memory 优化,任何检查点缓冲区在检查点保存完成后就会立即释放。而使用 Pinned Memory 实现时,此缓冲区会在步骤之间保持,导致相同的峰值内存压力贯穿应用程序生命周期。

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.checkpoint import StorageWriter

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355 "

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # The storage writer defines our 'staging' strategy, where staging is considered the process of copying
    # checkpoints to in-memory buffers. By setting `cached_state_dict=True`, we enable efficient memory copying
    # into a persistent buffer with pinned memory enabled.
    # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
    # pinned memory buffer.
    writer = StorageWriter(cached_state_dict=True)
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        state_dict = { "app": AppState(model, optimizer) }
        if checkpoint_future is not None:
            # waits for checkpointing to finish, avoiding queuing more then one checkpoint request at a time
            checkpoint_future.result()
        dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

结论

总而言之,我们学习了如何使用 DCP 的 async_save() API 在关键训练路径之外生成检查点。我们还了解了使用此 API 带来的额外内存和并发开销,以及利用 Pinned Memory 进一步加速的其他优化。


评价本教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源