使用全分片数据并行 (FSDP) 进行高级模型训练¶
作者:Hamid Shojanazeri、Less Wright、Rohan Varma、赵艳丽
本教程介绍了 PyTorch 1.12 版本中全分片数据并行 (FSDP) 的更多高级功能。要熟悉 FSDP,请参阅FSDP 入门教程。
在本教程中,我们使用 FSDP 对 HuggingFace (HF) T5 模型进行微调以进行文本摘要,作为工作示例。
此示例使用 Wikihow,为简单起见,我们将展示在单个节点 P4dn 实例上使用 8 个 A100 GPU 进行的训练。我们很快将在 PyTorch 的 Medium 频道发布一篇关于在多节点集群上进行大规模 FSDP 训练的博文,敬请期待。
FSDP 是一个生产就绪的软件包,专注于易用性、性能和长期支持。FSDP 的主要优势之一是减少每个 GPU 上的内存占用。这使得能够使用比 DDP 更少的总内存来训练更大的模型,并利用计算和通信的重叠来有效地训练模型。这种降低的内存压力可以用来训练更大的模型或增加批大小,从而可能有助于提高整体训练吞吐量。您可以此处阅读有关 PyTorch FSDP 的更多信息。
本教程中的 FSDP 功能¶
Transformer 自动包装策略
混合精度
在设备上初始化 FSDP 模型
分片策略
反向传播预取
通过流式传输到 CPU 保存模型检查点
FSDP 工作原理回顾¶
在高级别上,FDSP 的工作原理如下
在构造函数中
对模型参数进行分片,每个秩只保留其自己的分片
在正向传播中
运行 all_gather 收集所有秩的所有分片以恢复此 FSDP 单元的完整参数 运行正向传播计算
丢弃它刚刚收集的非拥有参数分片以释放内存
在反向传播中
运行 all_gather 收集所有 Rank 上的所有分片,以恢复此 FSDP 单元中的完整参数。运行反向计算。
丢弃非本机参数以释放内存。
运行 reduce_scatter 同步梯度。
微调 HF T5¶
HF T5 预训练模型有四种不同的大小,从小型的 6000 万参数到 XXL 的 110 亿参数不等。在本教程中,我们演示了使用 FSDP 对 T5 3B 进行微调,以使用 WikiHow 数据集进行文本摘要。本教程的主要重点是突出 FSDP 中不同的可用功能,这些功能有助于训练超过 3B 参数的大规模模型。此外,我们还介绍了针对 Transformer 基于模型的特定功能。本教程的代码可在 Pytorch 示例 中找到。
设置
1.1 安装 PyTorch nightly 版本
我们将安装 PyTorch nightly 版本,因为某些功能(例如激活检查点)在 nightly 版本中可用,并将添加到 1.12 之后的下一个 PyTorch 版本中。
pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
1.2 数据集设置
请创建一个 data 文件夹,从 wikihowAll.csv 和 wikihowSep.cs 下载 WikiHow 数据集,并将它们放置在 data 文件夹中。我们将使用来自 summarization_dataset 的 wikihow 数据集。
接下来,我们将以下代码片段添加到 Python 脚本“T5_training.py”中。
注意
本教程的完整源代码可在 PyTorch 示例 中找到。
1.3 导入必要的包
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing_wrapper)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime
1.4 分布式训练设置。在这里,我们使用两个辅助函数来初始化分布式训练的进程,然后在训练完成后进行清理。在本教程中,我们将使用 torch elastic,使用 torchrun,它将自动设置工作进程的 RANK 和 WORLD_SIZE。
def setup():
# initialize the process group
dist.init_process_group("nccl")
def cleanup():
dist.destroy_process_group()
2.1 设置 HuggingFace T5 模型
def setup_model(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
return model, tokenizer
我们还在此处添加了几个辅助函数,用于日期和格式化内存指标。
def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07-08:31:12_PM'
"""
date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
print(f"--> current date and time of run = {date_of_run}")
return date_of_run
def format_metrics_to_gb(item):
"""quick function to format numbers to gigabyte and round to 4 digit precision"""
metric_num = item / g_gigabyte
metric_num = round(metric_num, ndigits=4)
return metric_num
2.2 定义训练函数
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
model.train()
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(2).to(local_rank)
if sampler:
sampler.set_epoch(epoch)
if rank==0:
inner_pbar = tqdm.tqdm(
range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
)
for batch in train_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
optimizer.zero_grad()
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
loss = output["loss"]
loss.backward()
optimizer.step()
fsdp_loss[0] += loss.item()
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
train_accuracy = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(
f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
)
return train_accuracy
2.3 定义验证函数
def validation(model, rank, world_size, val_loader):
model.eval()
correct = 0
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(3).to(local_rank)
if rank == 0:
inner_pbar = tqdm.tqdm(
range(len(val_loader)), colour="green", desc="Validation Epoch"
)
with torch.no_grad():
for batch in val_loader:
for key in batch.keys():
batch[key] = batch[key].to(local_rank)
output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
fsdp_loss[0] += output["loss"].item() # sum up batch loss
fsdp_loss[1] += len(batch)
if rank==0:
inner_pbar.update(1)
dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
val_loss = fsdp_loss[0] / fsdp_loss[1]
if rank == 0:
inner_pbar.close()
print(f"Validation Loss: {val_loss:.4f}")
return val_loss
2.4 定义一个包装模型到 FSDP 中的分布式训练函数
def fsdp_main(args):
model, tokenizer = setup_model("t5-base")
local_rank = int(os.environ['LOCAL_RANK'])
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dataset = load_dataset('wikihow', 'all', data_dir='data/')
print(dataset.keys())
print("Size of train dataset: ", dataset['train'].shape)
print("Size of Validation dataset: ", dataset['validation'].shape)
#wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
setup()
train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2,
'pin_memory': True,
'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
torch.cuda.set_device(local_rank)
#init_start_event = torch.cuda.Event(enable_timing=True)
#init_end_event = torch.cuda.Event(enable_timing=True)
#init_start_event.record()
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
if bf16_ready:
mp_policy = bfSixteen
else:
mp_policy = None # defaults to fp32
# model is on CPU before input to FSDP
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=mp_policy,
#sharding_strategy=sharding_strategy,
device_id=torch.cuda.current_device())
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
best_val_loss = float("inf")
curr_val_loss = float("inf")
file_save_name = "T5-model-"
if rank == 0:
time_of_run = get_date_of_run()
dur = []
train_acc_tracking = []
val_acc_tracking = []
training_start_time = time.time()
if rank == 0 and args.track_memory:
mem_alloc_tracker = []
mem_reserved_tracker = []
for epoch in range(1, args.epochs + 1):
t0 = time.time()
train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
if args.run_validation:
curr_val_loss = validation(model, rank, world_size, val_loader)
scheduler.step()
if rank == 0:
print(f"--> epoch {epoch} completed...entering save and stats zone")
dur.append(time.time() - t0)
train_acc_tracking.append(train_accuracy.item())
if args.run_validation:
val_acc_tracking.append(curr_val_loss.item())
if args.track_memory:
mem_alloc_tracker.append(
format_metrics_to_gb(torch.cuda.memory_allocated())
)
mem_reserved_tracker.append(
format_metrics_to_gb(torch.cuda.memory_reserved())
)
print(f"completed save and stats zone...")
if args.save_model and curr_val_loss < best_val_loss:
# save
if rank == 0:
print(f"--> entering save model state")
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
#print(f"saving process: rank {rank} done w state_dict")
if rank == 0:
print(f"--> saving model ...")
currEpoch = (
"-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
)
print(f"--> attempting to save model prefix {currEpoch}")
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
print(f"--> saving as model name {save_name}")
torch.save(cpu_state, save_name)
if curr_val_loss < best_val_loss:
best_val_loss = curr_val_loss
if rank==0:
print(f"-->>>> New Val Loss Record: {best_val_loss}")
dist.barrier()
cleanup()
2.5 解析参数并设置主函数
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
parser.add_argument('--batch-size', type=int, default=4, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
help='number of epochs to train (default: 3)')
parser.add_argument('--lr', type=float, default=.002, metavar='LR',
help='learning rate (default: .002)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--track_memory', action='store_false', default=True,
help='track the gpu memory')
parser.add_argument('--run_validation', action='store_false', default=True,
help='running the validation')
parser.add_argument('--save-model', action='store_false', default=True,
help='For Saving the current Model')
args = parser.parse_args()
torch.manual_seed(args.seed)
fsdp_main(args)
要使用 torchrun 运行训练
torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
Transformer 包装策略¶
如 上一教程 中所述,auto_wrap_policy 是 FSDP 的功能之一,它可以轻松地自动分片给定模型,并将模型、优化器和梯度分片放入不同的 FSDP 单元中。
对于某些架构(例如 Transformer 编码器-解码器),模型的某些部分(例如嵌入表)与编码器和解码器共享。在这种情况下,我们需要将嵌入表放置在外部 FSDP 单元中,以便编码器和解码器都可以访问它。此外,通过注册 Transformer 的层类,可以使分片计划的通信效率更高。在 PyTorch 1.12 中,FSDP 添加了此支持,现在我们有了用于 Transformer 的包装策略。
它可以按如下方式创建,其中 T5Block 表示 T5 Transformer 层类(包含 MHSA 和 FFN)。
t5_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
T5Block,
},
)
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy)
要查看包装后的模型,您可以轻松地打印模型并直观地检查分片和 FSDP 单元。
混合精度¶
FSDP 支持灵活的混合精度训练,允许使用任意降低精度的类型(例如 fp16 或 bfloat16)。目前,BFloat16 仅在 Ampere GPU 上可用,因此您需要在使用它之前确认其原生支持。例如,在 V100 上,仍然可以运行 BFloat16,但由于它是非原生运行,因此会导致明显的减速。
要检查 BFloat16 是否原生支持,您可以使用以下方法
bf16_ready = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and LooseVersion(torch.version.cuda) >= "11.0"
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
FSDP 中混合精度的优点之一是提供对参数、梯度和缓冲区的不同精度级别的细粒度控制,如下所示
fpSixteen = MixedPrecision(
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
# Buffer precision.
buffer_dtype=torch.float16,
)
bfSixteen = MixedPrecision(
param_dtype=torch.bfloat16,
# Gradient communication precision.
reduce_dtype=torch.bfloat16,
# Buffer precision.
buffer_dtype=torch.bfloat16,
)
fp32_policy = MixedPrecision(
param_dtype=torch.float32,
# Gradient communication precision.
reduce_dtype=torch.float32,
# Buffer precision.
buffer_dtype=torch.float32,
)
请注意,如果未指定某种类型(参数、减少、缓冲区),则它们根本不会进行转换。
这种灵活性允许用户进行细粒度控制,例如仅将梯度通信设置为以降低精度进行,并将所有参数/缓冲区计算设置为以全精度进行。这在节点内通信是主要瓶颈并且参数/缓冲区必须以全精度进行以避免精度问题的情况下可能很有用。这可以通过以下策略完成
grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
在 2.4 中,我们只需将相关的混合精度策略添加到 FSDP 包装器中即可。
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen)
在我们的实验中,我们观察到使用 BFloat16 进行训练可以提高高达 4 倍的速度,并且在某些实验中内存减少了大约 30%,这些内存可用于增加批次大小。
在设备上初始化 FSDP 模型¶
在 1.12 中,FSDP 支持一个 device_id 参数,用于在由 device_id 指定的设备上初始化输入 CPU 模块。当整个模型无法容纳在一个 GPU 上,但可以容纳在主机 CPU 内存中时,这很有用。当指定 device_id 时,FSDP 将根据每个 FSDP 单元的依据将模型移动到指定的设备,避免 GPU OOM 问题,同时初始化速度比基于 CPU 的初始化快得多。
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device())
分片策略¶
默认情况下,FSDP 分片策略设置为完全分片模型参数,梯度和优化器状态在所有 Rank 上分片。(也称为 Zero3 分片)。如果您有兴趣使用 Zero2 分片策略,其中仅分片优化器状态和梯度,则 FSDP 通过使用“ShardingStrategy.SHARD_GRAD_OP”而不是“ShardingStrategy.FULL_SHARD”将分片策略传递给 FSDP 初始化来支持此功能,如下所示
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)
这将减少 FSDP 中的通信开销,在这种情况下,它在正向和反向传递后保存完整参数。
这节省了反向传播期间的 all_gather,因此通信量减少,但内存占用量更大。请注意,完整模型参数在反向传播结束时释放,并且 all_gather 将在下一次正向传递时发生。
反向预取¶
反向预取设置控制何时请求下一个 FSDP 单元的参数的时机。通过将其设置为 BACKWARD_PRE,下一个 FSDP 单元的参数可以开始请求并更快地到达,在当前单元的计算开始之前。这使 all_gather 通信和梯度计算重叠,从而可以提高训练速度,但代价是略微提高内存消耗。它可以在 2.4 中的 FSDP 包装器中使用,如下所示
torch.cuda.set_device(local_rank)
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=bfSixteen,
device_id=torch.cuda.current_device(),
backward_prefetch = BackwardPrefetch.BACKWARD_PRE)
backward_prefetch 有两种模式,BACKWARD_PRE 和 BACKWARD_POST。BACKWARD_POST 表示只有在当前 FSDP 单元处理完成后才会请求下一个 FSDP 单元的参数,从而最大限度地减少内存开销。在某些情况下,使用 BACKWARD_PRE 可以将模型训练速度提高 2-10%,对于更大的模型,速度提升甚至更高。
模型检查点保存,通过流式传输到 Rank0 CPU¶
要使用 FULL_STATE_DICT 保存保存模型检查点,该检查点以与本地模型相同的方式保存模型,PyTorch 1.12 提供了一些实用程序来支持保存更大的模型。
首先,可以指定 FullStateDictConfig,允许仅在 Rank 0 上填充 state_dict 并卸载到 CPU。
使用此配置时,FSDP 将收集模型参数,并将它们逐一卸载到 CPU,仅在 Rank 0 上进行。当最终保存 state_dict 时,它将仅在 Rank 0 上填充并包含 CPU 张量。这避免了对于大于单个 GPU 内存的模型可能出现的 OOM 问题,并允许用户检查点模型的大小大致等于用户机器上的可用 CPU RAM。
此功能可以按如下方式运行
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, save_policy
):
cpu_state = model.state_dict()
if rank == 0:
save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
torch.save(cpu_state, save_name)
总结¶
在本教程中,我们介绍了 Pytorch 1.12 中可用的许多 FSDP 新功能,并使用 HF T5 作为运行示例。使用正确的包装策略,特别是对于 Transformer 模型,以及混合精度和反向预取,应该可以加快您的训练运行速度。此外,诸如在设备上初始化模型以及通过流式传输到 CPU 进行检查点保存等功能应该有助于避免在处理大型模型时出现 OOM 错误。
我们正在积极努力为下一个版本添加 FSDP 的新功能。如果您有任何反馈、功能请求、问题或在使用 FSDP 时遇到问题,请随时通过在 PyTorch Github 存储库 中打开问题来与我们联系。