分布式检查点 (DCP) 入门¶
**作者**:Iris Zhang、Rodrigo Kumpera、Chien-Chin Huang、Lucas Pasqualin
注意
在 github 上查看和编辑本教程。
先决条件
在分布式训练期间检查点 AI 模型可能具有挑战性,因为参数和梯度在训练器之间进行分区,并且在恢复训练时可用的训练器数量可能会发生变化。Pytorch 分布式检查点 (DCP) 可以帮助简化此过程。
在本教程中,我们将展示如何使用 DCP API 与一个简单的 FSDP 包装模型。
DCP 的工作原理¶
torch.distributed.checkpoint()
支持并行地从多个进程保存和加载模型。您可以使用此模块并行地保存在任意数量的进程上,然后在加载时跨不同的集群拓扑重新分片。
此外,通过使用 torch.distributed.checkpoint.state_dict()
中的模块,DCP 提供了对在分布式环境中优雅地处理 state_dict
生成和加载的支持。这包括管理跨模型和优化器的完全限定名称 (FQN) 映射,以及为 PyTorch 提供的并行性设置默认参数。
DCP 在几个重要方面不同于 torch.save()
和 torch.load()
它为每个检查点生成多个文件,每个进程至少生成一个。
它是在原位操作,这意味着模型应该首先分配其数据,并且 DCP 使用该存储而不是重新分配。
DCP 提供了对有状态对象(在 torch.distributed.checkpoint.stateful 中正式定义)的特殊处理,如果定义了 state_dict 和 load_state_dict 方法,则会自动调用它们。
注意
本教程中的代码在一个 8 个 GPU 的服务器上运行,但它可以很容易地推广到其他环境。
如何使用 DCP¶
这里我们使用一个用 FSDP 包装的玩具模型进行演示。类似地,这些 API 和逻辑可以应用于更大的模型进行检查点保存。
保存¶
现在,让我们创建一个玩具模块,用 FSDP 包装它,用一些虚拟输入数据馈送它,并保存它。
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355 "
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def run_fsdp_checkpoint_save_example(rank, world_size):
print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
setup(rank, world_size)
# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
model = FSDP(model)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
optimizer.zero_grad()
model(torch.rand(8, 16, device="cuda")).sum().backward()
optimizer.step()
state_dict = { "app": AppState(model, optimizer) }
dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
print(f"Running fsdp checkpoint example on {world_size} devices.")
mp.spawn(
run_fsdp_checkpoint_save_example,
args=(world_size,),
nprocs=world_size,
join=True,
)
请继续检查 checkpoint 目录。您应该会看到如下所示的 8 个检查点文件。
加载¶
保存后,让我们创建相同的 FSDP 包装模型,并将保存的状态字典从存储加载到模型中。您可以使用相同的 world size 或不同的 world size 加载。
请注意,您必须在加载之前调用 model.state_dict()
并将其传递给 DCP 的 load_state_dict()
API。这与 torch.load()
从根本上不同,因为 torch.load()
只需要加载前检查点的路径。我们需要加载前 state_dict
的原因是
DCP 使用模型 state_dict 中预先分配的存储从检查点目录加载。在加载过程中,传入的 state_dict 将被就地更新。
DCP 需要加载前模型的分片信息来支持重新分片。
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355 "
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def run_fsdp_checkpoint_load_example(rank, world_size):
print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
setup(rank, world_size)
# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
model = FSDP(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
state_dict = { "app": AppState(model, optimizer)}
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
print(f"Running fsdp checkpoint example on {world_size} devices.")
mp.spawn(
run_fsdp_checkpoint_load_example,
args=(world_size,),
nprocs=world_size,
join=True,
)
如果您想将保存的检查点加载到非 FSDP 包装模型(例如用于推理)的非分布式设置中,您也可以使用 DCP 执行此操作。默认情况下,DCP 以单程序多数据 (SPMD) 样式保存和加载分布式 state_dict
。但是,如果未初始化进程组,则 DCP 会推断出意图是以“非分布式”样式保存或加载,这意味着完全在当前进程中。
注意
多程序多数据分布式检查点支持仍在开发中。
import os
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
CHECKPOINT_DIR = "checkpoint"
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def run_checkpoint_load_example():
# create the non FSDP-wrapped toy model
model = ToyModel()
state_dict = {
"model": model.state_dict(),
}
# since no progress group is initialized, DCP will disable any collectives.
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
model.load_state_dict(state_dict["model"])
if __name__ == "__main__":
print(f"Running basic DCP checkpoint loading example.")
run_checkpoint_load_example()
格式¶
尚未提及的一个缺点是,DCP 以与使用 torch.save 生成的检查点本质上不同的格式保存检查点。当用户希望与习惯使用 torch.save 格式的用户共享模型时,或者通常只是希望为其应用程序添加格式灵活性时,这可能是一个问题。对于这种情况,我们在 torch.distributed.checkpoint.format_utils
中提供了 format_utils
模块。
为了方便用户,提供了一个命令行实用程序,其格式如下
python -m torch.distributed.checkpoint.format_utils -m <checkpoint location> <location to write formats to> <mode>
在上面的命令中,mode
是 torch_to_dcp`
或 dcp_to_torch
之一。
或者,也为希望直接转换检查点用户提供了方法。
import os
import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"
# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)
# converts the torch.save model back to DCP
dcp_to_torch_save(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
结论¶
总之,我们学习了如何使用 DCP 的 save()
和 load()
API,以及它们与 torch.save()
和 torch.load()
的区别。此外,我们还学习了如何使用 get_state_dict()
和 set_state_dict()
在状态字典生成和加载期间自动管理特定于并行的 FQN 和默认值。
更多信息,请参见以下内容