• 教程 >
  • 使用 Fully Sharded Data Parallel (FSDP) 进行高级模型训练
快捷方式

使用 Fully Sharded Data Parallel (FSDP) 进行高级模型训练

创建于:2024 年 10 月 31 日 | 最后更新:2024 年 10 月 31 日 | 最后验证:2024 年 11 月 5 日

作者Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

你将学到什么
  • PyTorch 的 Fully Sharded Data Parallel 模块:一个用于在

数据并行工作节点之间分片模块参数的包装器。

先决条件
  • PyTorch 1.12 或更高版本

  • 阅读关于 FSDP API 的内容。

本教程介绍了作为 PyTorch 1.12 发布的一部分的 Fully Sharded Data Parallel (FSDP) 更高级的功能。要熟悉 FSDP,请参阅 FSDP 入门教程

在本教程中,我们使用 FSDP 微调 HuggingFace (HF) T5 模型进行文本摘要,作为一个工作示例。

本示例使用 WikiHow 数据集,为简单起见,我们将演示在具有 8 个 A100 GPU 的单节点 P4dn 实例上进行训练。我们现在有多篇博客文章 ((链接 1), (链接 2)) 和一篇关于在多节点集群上进行大规模 FSDP 训练的论文

FSDP 是一个生产就绪的软件包,专注于易用性、性能和长期支持。FSDP 的主要优势之一是减少每个 GPU 的内存占用。这使得可以使用比 DDP 更低的总内存训练更大的模型,并利用计算和通信的重叠来高效地训练模型。这种减轻的内存压力可以用来训练更大的模型或增加批量大小,这可能有助于提高整体训练吞吐量。你可以在此处阅读更多关于 PyTorch FSDP 的信息。

本教程中的 FSDP 特性

  • Transformer 自动包装策略

  • 混合精度

  • 在设备上初始化 FSDP 模型

  • 分片策略

  • 反向预取

  • 通过流式传输到 CPU 保存模型检查点

FSDP 工作原理回顾

从高层次看,FSDP 工作流程如下

在构造函数中

  • 分片模型参数,每个 Rank 只保留自己的分片

在前向传播中

  • 运行 all_gather 收集所有 Rank 的所有分片,以恢复此 FSDP 单元的完整参数,并运行前向计算

  • 丢弃刚刚收集的非自身拥有的参数分片以释放内存

在反向传播中

  • 运行 all_gather 收集所有 Rank 的所有分片,以恢复此 FSDP 单元的完整参数,并运行反向计算

  • 丢弃非自身拥有的参数以释放内存。

  • 运行 reduce_scatter 以同步梯度

微调 HF T5

HF T5 预训练模型有四种不同大小,从 6000 万参数的小型模型到 110 亿参数的 XXL 模型。在本教程中,我们演示了使用 FSDP 微调 T5 3B 模型进行文本摘要,使用 WikiHow 数据集。本教程的主要重点是突出 FSDP 中可用的不同功能,这些功能对于训练超过 30 亿参数的大规模模型非常有用。此外,我们还介绍了基于 Transformer 的模型的特定功能。本教程的代码可在 PyTorch examples 中找到。

设置

1.1 安装最新版 PyTorch

pip3 install torch torchvision torchaudio

1.2 数据集设置

请创建一个 data 文件夹,从 wikihowAll.csvwikihowSep.cs 下载 WikiHow 数据集,并将它们放在 data 文件夹中。我们将使用来自 summarization_dataset 的 wikihow 数据集。

接下来,我们将以下代码片段添加到 Python 脚本“T5_training.py”中。

注意

本教程的完整源代码可在 PyTorch examples 中找到。

1.3 导入必要的包

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl,
 apply_activation_checkpointing_wrapper)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 分布式训练设置。在这里,我们使用两个辅助函数来初始化分布式训练的进程,并在训练完成后进行清理。在本教程中,我们将使用 torch elastic,通过 torchrun,这将自动设置工作节点的 RANKWORLD_SIZE

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

2.1 设置 HuggingFace T5 模型

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

此外,我们在这里添加了几个用于日期和格式化内存指标的辅助函数。

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 定义训练函数

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 定义验证函数

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 定义一个将模型封装在 FSDP 中的分布式训练函数

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

2.5 解析参数并设置主函数

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)

要使用 torchrun 运行训练

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py

Transformer 包装策略

之前的教程所述,auto_wrap_policy 是 FSDP 的特性之一,它使得自动分片给定的模型并将模型、优化器和梯度分片放入不同的 FSDP 单元变得容易。

对于某些架构,例如 Transformer 编码器-解码器,模型的某些部分(例如嵌入表)与编码器和解码器共享。在这种情况下,我们需要将嵌入表放在外部 FSDP 单元中,以便编码器和解码器都可以访问它。此外,通过注册 Transformer 的层类,可以使分片计划更具通信效率。在 PyTorch 1.12 中,FSDP 添加了此支持,现在我们有了一个用于 Transformer 的包装策略。

可以按如下方式创建,其中 T5Block 表示 T5 Transformer 层类(包含 MHSA 和 FFN)。

t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
torch.cuda.set_device(local_rank)


model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)

要查看包装后的模型,你可以轻松打印模型并直观检查分片和 FSDP 单元。

混合精度

FSDP 支持灵活的混合精度训练,允许任意的降低精度类型(例如 fp16 或 bfloat16)。目前 BFloat16 仅在 Ampere GPU 上可用,因此在使用之前需要确认原生支持。例如,在 V100 上仍然可以运行 BFloat16,但由于它是非原生运行的,可能会导致明显的性能下降。

要检查 BFloat16 是否原生支持,你可以使用以下方法

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

FSDP 中混合精度的一个优点是提供对参数、梯度和缓冲区的不同精度级别的细粒度控制,如下所示

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

请注意,如果未指定某种类型(参数、reduce、缓冲区),则不会进行任何类型转换。

这种灵活性允许用户进行细粒度控制,例如只设置梯度通信在降低精度下进行,而所有参数/缓冲区计算都在全精度下进行。这在节点内通信是主要瓶颈且参数/缓冲区必须是全精度以避免精度问题的情况下可能很有用。可以通过以下策略实现:

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

在 2.4 中,我们只需将相关的混合精度策略添加到 FSDP 包装器中

model = FSDP(model,
       auto_wrap_policy=t5_auto_wrap_policy,
       mixed_precision=bfSixteen)

在我们的实验中,我们观察到使用 BFloat16 进行训练可以将速度提高多达 4 倍,并在某些实验中将内存减少约 30%,这可用于增加批量大小。

在设备上初始化 FSDP 模型

在 1.12 版本中,FSDP 支持一个 device_id 参数,用于在 device_id 指定的设备上初始化输入的 CPU 模块。这在整个模型无法容纳在单个 GPU 上,但可以容纳在主机 CPU 内存中的情况下非常有用。指定 device_id 后,FSDP 会按每个 FSDP 单元将模型移动到指定的设备,从而避免 GPU OOM 问题,同时初始化速度比基于 CPU 的初始化快几倍。

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())

分片策略

FSDP 分片策略默认设置为完全分片模型参数、梯度和优化器状态,将它们分片到所有 Rank。(也称为 Zero3 分片)。如果你有兴趣使用 Zero2 分片策略(仅对优化器状态和梯度进行分片),FSDP 支持此功能,方法是在 FSDP 初始化时传递分片策略,使用“ShardingStrategy.SHARD_GRAD_OP”而不是“ShardingStrategy.FULL_SHARD”,如下所示

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

这将减少 FSDP 中的通信开销,在这种情况下,它在前向和反向传播后会保留完整的参数。

这在反向传播期间节省了一次 all_gather 操作,因此通信量减少,但代价是内存占用较高。请注意,完整的模型参数在反向传播结束时会释放,并在下一次前向传播时发生 all_gather。

反向预取

反向预取设置控制何时请求下一个 FSDP 单元参数的时机。通过将其设置为 BACKWARD_PRE,可以在当前单元计算开始之前更早地请求并获取下一个 FSDP 单元的参数。这使得 all_gather 通信与梯度计算重叠,从而提高训练速度,代价是略微增加内存消耗。可以在 2.4 中的 FSDP 包装器中使用它,如下所示

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch 有两种模式:BACKWARD_PREBACKWARD_POSTBACKWARD_POST 意味着在当前 FSDP 单元处理完成之前,不会请求下一个 FSDP 单元的参数,从而最大限度地减少内存开销。在某些情况下,使用 BACKWARD_PRE 可以将模型训练速度提高 2-10%,对于较大的模型甚至可以观察到更高的速度提升。

通过流式传输到 Rank0 CPU 保存模型检查点

为了使用 FULL_STATE_DICT 保存模型检查点(这种方式与保存本地模型类似),PyTorch 1.12 提供了一些工具来支持保存较大的模型。

首先,可以指定一个 FullStateDictConfig,允许仅在 Rank 0 上填充 state_dict 并卸载到 CPU。

使用此配置时,FSDP 将 allgather 模型参数,仅在 Rank 0 上将它们逐个卸载到 CPU。当最终保存 state_dict 时,它将仅在 Rank 0 上填充并包含 CPU 张量。这避免了对于大于单个 GPU 内存的模型可能出现的 OOM 问题,并允许用户保存大小大致与用户机器上可用 CPU RAM 相当的模型检查点。

此功能可以按如下方式运行

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
if rank == 0:
 save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
 torch.save(cpu_state, save_name)

总结

在本教程中,我们介绍了 PyTorch 1.12 中 FSDP 的许多新特性,并使用 HF T5 作为运行示例。使用适当的包装策略(特别是对于 Transformer 模型),以及混合精度和反向预取,应该可以加速你的训练运行。此外,诸如在设备上初始化模型和通过流式传输到 CPU 保存检查点等特性应该有助于避免处理大型模型时的 OOM 错误。

我们正在积极努力为 FSDP 的下一个版本添加新功能。如果你有反馈、功能请求、问题或在使用 FSDP 时遇到问题,请随时通过在 PyTorch Github 仓库中提出 issue 来联系我们。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源