Asynchronous Saving with Distributed Checkpoint (DCP)¶
Created On: Jul 22, 2024 | Last Updated: Jul 22, 2024 | Last Verified: Nov 05, 2024
Authors: Lucas Pasqualin, Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang
Checkpointing is often a bottleneck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow. One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example from the Getting Started with Distributed Checkpoint Tutorial to show how this can be integrated quite easily with torch.distributed.checkpoint.async_save.
What you will learn
- How to use DCP to generate checkpoints in parallel
- Effective strategies to optimize performance

Prerequisites
- PyTorch v2.4.0 or later
Asynchronous Checkpointing Overview¶
Before getting started with asynchronous checkpointing, it's important to understand its differences and limitations compared to synchronous checkpointing. Specifically:
- Memory requirements - Asynchronous checkpointing works by first copying the model into internal CPU buffers. This helps because it ensures that model and optimizer weights are not changing while the checkpoint is still being saved, but it does raise CPU memory usage by roughly checkpoint_size_per_rank X number_of_ranks. Additionally, users should take care to understand the memory constraints of their systems; in particular, pinned memory implies the use of page-locked memory, which is scarcer than pageable memory. A rough way to estimate this overhead is sketched after this list.
- Checkpoint management - Since checkpointing is asynchronous, it is up to the user to manage concurrently running checkpoints. In general, users can employ their own management strategies by handling the Future object returned from async_save. For most users, we recommend limiting checkpoints to one asynchronous request at a time, which avoids additional memory pressure per request.
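The extra CPU memory introduced by staging can be estimated directly from the state being checkpointed. The helper below is only an illustrative sketch (the name estimate_staging_bytes is ours, not a DCP API), and it ignores details such as sharded-tensor local sizes:

import torch

def estimate_staging_bytes(obj) -> int:
    # Recursively sum tensor sizes in a (possibly nested) state dict to approximate
    # the per-rank CPU staging copy made by asynchronous checkpointing.
    if torch.is_tensor(obj):
        return obj.numel() * obj.element_size()
    if isinstance(obj, dict):
        return sum(estimate_staging_bytes(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(estimate_staging_bytes(v) for v in obj)
    return 0

# Cluster-wide staging memory is then roughly:
#   estimate_staging_bytes(per_rank_state_dict) * number_of_ranks

The complete example below puts these pieces together, saving asynchronously from an FSDP training loop.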
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # waits for the previous checkpoint to finish if one exists, avoiding queuing more than one checkpoint request at a time
        if checkpoint_future is not None:
            checkpoint_future.result()

        state_dict = { "app": AppState(model, optimizer) }
        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running async checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
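Because AppState also implements load_state_dict, the same wrapper can be reused to restore a checkpoint later. The snippet below is a minimal sketch, assuming the same process-group setup and model/optimizer construction as above, and that a checkpoint for the chosen step already exists on disk (step 5 is only an example):

# Reconstruct the objects that will receive the loaded state.
model = FSDP(ToyModel().to(rank))
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

state_dict = {"app": AppState(model, optimizer)}
# dcp.load calls AppState.load_state_dict, which in turn calls set_state_dict to push the
# loaded model and optimizer state back onto the wrapped objects.
dcp.load(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step5")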
Even more performance with Pinned Memory¶
If the above optimization is still not performant enough, you can take advantage of an additional optimization for GPU models which utilizes a pinned memory buffer for checkpoint staging. Specifically, this optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers. By maintaining a pinned memory buffer between checkpoint requests, users can take advantage of direct memory access to speed up this copy.
Note
The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps, leading to the same peak memory pressure being sustained through the application life.
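To see why a persistent pinned buffer helps, consider the plain PyTorch sketch below. It is only a conceptual illustration of device-to-host staging, not DCP's internal implementation: the page-locked buffer is allocated once and reused for each checkpoint's copy off the GPU.

import torch

# A toy stand-in for one shard of model/optimizer state living on the GPU.
gpu_tensor = torch.randn(1024, 1024, device="cuda")

# Allocated once and reused across steps: page-locked (pinned) host memory lets the GPU's
# DMA engine copy into it directly, without staging through pageable memory first.
pinned_buffer = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, device="cpu", pin_memory=True)

for step in range(3):
    # ... a training step would update gpu_tensor here ...
    pinned_buffer.copy_(gpu_tensor, non_blocking=True)  # asynchronous device-to-host copy
    torch.cuda.synchronize()  # make sure the copy finished before a background thread serializes it

DCP maintains a buffer like this for you when the storage writer persists between requests, as in the full example below.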
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # The storage writer defines our 'staging' strategy, where staging is considered the process of copying
    # checkpoints to in-memory buffers. By setting `cache_staged_state_dict=True`, we enable efficient memory copying
    # into a persistent buffer with pinned memory enabled.
    # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
    # pinned memory buffer.
    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        state_dict = { "app": AppState(model, optimizer) }
        if checkpoint_future is not None:
            # waits for the previous checkpoint to finish, avoiding queuing more than one checkpoint request at a time
            checkpoint_future.result()
        checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
Conclusion¶
In conclusion, we have learned how to use DCP's async_save() API to generate checkpoints off the critical training path. We've also learned about the additional memory and concurrency overhead introduced by using this API, as well as the additional optimization that utilizes pinned memory to speed things up even further.