• 教程 >
  • 使用全分片数据并行 (FSDP) 的高级模型训练
快捷方式

使用全分片数据并行 (FSDP) 的高级模型训练

创建于:2024 年 10 月 31 日 | 最后更新:2024 年 10 月 31 日 | 最后验证:2024 年 11 月 05 日

作者: Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

您将学到什么
  • PyTorch 的全分片数据并行模块:一个用于在数据并行工作进程之间分片模块参数的包装器。

数据并行工作进程。

先决条件
  • PyTorch 1.12 或更高版本

  • 阅读关于 FSDP API 的信息。

本教程介绍了 PyTorch 1.12 版本中全分片数据并行 (FSDP) 的更高级功能。要熟悉 FSDP,请参考 FSDP 入门教程

在本教程中,我们将使用 FSDP 微调 HuggingFace (HF) T5 模型,以进行文本摘要作为工作示例。

该示例使用 WikiHow,为了简单起见,我们将展示在具有 8 个 A100 GPU 的单节点 P4dn 实例上的训练。我们现在有几篇博客文章 ( (链接 1), (链接 2)) 和一篇关于多节点集群上大规模 FSDP 训练的论文

FSDP 是一个生产就绪的软件包,专注于易用性、性能和长期支持。FSDP 的主要优势之一是减少每个 GPU 上的内存占用。与 DDP 相比,这使得可以使用更少的总内存来训练更大的模型,并利用计算和通信的重叠来高效地训练模型。这种降低的内存压力可以用于训练更大的模型或增加批量大小,从而可能有助于提高整体训练吞吐量。您可以在此处阅读更多关于 PyTorch FSDP 的信息。

本教程中的 FSDP 功能

  • Transformer 自动包装策略

  • 混合精度

  • 在设备上初始化 FSDP 模型

  • 分片策略

  • 反向预取

  • 通过流式传输到 CPU 保存模型检查点

FSDP 工作原理回顾

在高层次上,FDSP 的工作方式如下

在构造函数中

  • 分片模型参数,每个 rank 仅保留自己的分片

在前向传播中

  • 运行 all_gather 以收集来自所有 rank 的所有分片,以恢复此 FSDP 单元的完整参数并运行前向计算

  • 丢弃刚刚收集的非自有参数分片以释放内存

在反向传播中

  • 运行 all_gather 以收集来自所有 rank 的所有分片,以恢复此 FSDP 单元中的完整参数并运行反向计算

  • 丢弃非自有参数以释放内存。

  • 运行 reduce_scatter 以同步梯度

微调 HF T5

HF T5 预训练模型有四种不同尺寸,从小型的 6 千万参数到 XXL 的 110 亿参数不等。在本教程中,我们演示了使用 FSDP 微调 T5 3B 以进行文本摘要,使用 WikiHow 数据集。本教程的主要重点是突出 FSDP 中可用于训练超过 30 亿参数的大规模模型的不同可用功能。此外,我们还介绍了基于 Transformer 模型的特定功能。本教程的代码可在 Pytorch 示例中找到。

设置

1.1 安装最新的 PyTorch

pip3 install torch torchvision torchaudio

1.2 数据集设置

请创建一个 data 文件夹,从 wikihowAll.csvwikihowSep.cs 下载 WikiHow 数据集,并将它们放在 data 文件夹中。我们将使用来自 summarization_dataset 的 wikihow 数据集。

接下来,我们将以下代码片段添加到 Python 脚本 “T5_training.py” 中。

注意

本教程的完整源代码可在 PyTorch 示例中找到。

1.3 导入必要的包

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl,
 apply_activation_checkpointing_wrapper)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 分布式训练设置。这里我们使用两个辅助函数来初始化分布式训练的进程,然后在训练完成后进行清理。在本教程中,我们将使用 torch elastic,使用 torchrun ,它将自动设置 worker 的 RANKWORLD_SIZE

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

2.1 设置 HuggingFace T5 模型

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

我们还在这里添加了几个辅助函数,用于日期和格式化内存指标。

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 定义训练函数

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 定义验证函数

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 定义一个分布式训练函数,该函数将模型包装在 FSDP 中

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

2.5 解析参数并设置主函数

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)

要使用 torchrun 运行训练

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py

Transformer 包装策略

正如在 之前的教程 中讨论的那样,auto_wrap_policy 是 FSDP 的功能之一,它使自动分片给定模型并将模型、优化器和梯度分片放入不同的 FSDP 单元变得容易。

对于某些架构(如 Transformer 编码器-解码器),模型的某些部分(如嵌入表)与编码器和解码器共享。在这种情况下,我们需要将嵌入表放在外部 FSDP 单元中,以便可以从编码器和解码器访问它。此外,通过注册 Transformer 的层类,可以使分片计划的通信效率更高。在 PyTorch 1.12 中,FSDP 添加了此支持,现在我们有了 Transformer 的包装策略。

它可以按如下方式创建,其中 T5Block 代表 T5 Transformer 层类(包含 MHSA 和 FFN)。

t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
torch.cuda.set_device(local_rank)


model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)

要查看包装后的模型,您可以轻松打印模型并直观地检查分片和 FSDP 单元。

混合精度

FSDP 支持灵活的混合精度训练,允许任意降低的精度类型(如 fp16 或 bfloat16)。目前,BFloat16 仅在 Ampere GPU 上可用,因此您需要在使用前确认本机支持。例如,在 V100 上,BFloat16 仍然可以运行,但由于它以非原生方式运行,可能会导致明显的减速。

要检查是否本机支持 BFloat16,您可以使用以下命令

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

FSDP 中混合精度的优势之一是为参数、梯度和缓冲区提供对不同精度级别的精细控制,如下所示

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

请注意,如果未指定特定类型(参数、reduce、缓冲区),则根本不会进行转换。

这种灵活性允许用户进行细粒度的控制,例如仅将梯度通信设置为以降低的精度进行,而所有参数/缓冲区计算都以全精度完成。这在节点内通信是主要瓶颈,并且参数/缓冲区必须以全精度才能避免精度问题的情况下可能很有用。这可以通过以下策略完成

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

在 2.4 中,我们只需将相关的混合精度策略添加到 FSDP 包装器中

model = FSDP(model,
       auto_wrap_policy=t5_auto_wrap_policy,
       mixed_precision=bfSixteen)

在我们的实验中,我们观察到使用 BFloat16 进行训练的速度提高了 4 倍,并且在某些实验中内存减少了约 30%,这可用于增加批量大小。

在设备上初始化 FSDP 模型

在 1.12 中,FSDP 支持 device_id 参数,旨在在 device_id 给定的设备上初始化输入 CPU 模块。当整个模型不适合单个 GPU,但适合主机的 CPU 内存时,这非常有用。当指定 device_id 时,FSDP 会将模型移动到每个 FSDP 单元指定的设备上,避免 GPU OOM 问题,同时初始化速度比基于 CPU 的初始化快几倍

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())

分片策略

默认情况下,FSDP 分片策略设置为完全分片模型参数,梯度和优化器状态在所有 rank 之间分片。(也称为 Zero3 分片)。如果您有兴趣使用 Zero2 分片策略,其中仅分片优化器状态和梯度,FSDP 通过使用 “ShardingStrategy.SHARD_GRAD_OP” 而不是 “ShardingStrategy.FULL_SHARD” 将分片策略传递给 FSDP 初始化来支持此功能,如下所示

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

这将减少 FSDP 中的通信开销,在这种情况下,它在前向传播和反向传播期间都保持完整的参数。

这节省了反向传播期间的 all_gather,因此通信量减少了,但内存占用量更高。请注意,完整模型参数在反向传播结束时释放,all_gather 将在下一个前向传播时发生。

反向预取

反向预取设置控制何时应请求下一个 FSDP 单元的参数。通过将其设置为 BACKWARD_PRE,可以在当前单元的计算开始之前更早地开始请求和到达下一个 FSDP 单元的参数。这重叠了 all_gather 通信和梯度计算,可以提高训练速度,但会略微增加内存消耗。它可以在 2.4 中的 FSDP 包装器中如下使用

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch 有两种模式,BACKWARD_PREBACKWARD_POSTBACKWARD_POST 意味着在当前 FSDP 单元处理完成之前,不会请求下一个 FSDP 单元的参数,从而最大限度地减少内存开销。在某些情况下,使用 BACKWARD_PRE 可以将模型训练速度提高 2-10%,对于更大的模型,速度提升甚至更高。

模型检查点保存,通过流式传输到 Rank0 CPU

要使用 FULL_STATE_DICT 保存模型检查点,这以与本地模型相同的方式保存模型,PyTorch 1.12 提供了几个实用程序来支持保存更大的模型。

首先,可以指定 FullStateDictConfig,允许仅在 rank 0 上填充 state_dict 并卸载到 CPU。

使用此配置时,FSDP 将 allgather 模型参数,一次一个地将它们卸载到 CPU,仅在 rank 0 上。当最终保存 state_dict 时,它将仅在 rank 0 上填充并包含 CPU 张量。这避免了模型大于单个 GPU 内存时可能出现的 OOM,并允许用户检查点模型,其大小大致为用户机器上可用的 CPU RAM。

此功能可以按如下方式运行

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
if rank == 0:
 save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
 torch.save(cpu_state, save_name)

总结

在本教程中,我们介绍了 Pytorch 1.12 中可用的 FSDP 的许多新功能,并使用 HF T5 作为运行示例。使用适当的包装策略(特别是对于 Transformer 模型)、混合精度和反向预取应该会加快您的训练运行速度。此外,诸如在设备上初始化模型以及通过流式传输到 CPU 保存检查点等功能应有助于避免处理大型模型时出现 OOM 错误。

我们正在积极努力为 FSDP 的下一个版本添加新功能。如果您有反馈、功能请求、问题或在使用 FSDP 时遇到问题,请随时通过在 PyTorch Github 仓库中打开 issue 与我们联系。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源