• 教程 >
  • 使用完全分片数据并行 (FSDP) 的高级模型训练
快捷方式

使用完全分片数据并行 (FSDP) 的高级模型训练

作者: Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

本教程介绍了 PyTorch 1.12 版本中完全分片数据并行 (FSDP) 的更高级功能。要熟悉 FSDP,请参考 FSDP 入门教程

在本教程中,我们将使用 FSDP 微调 HuggingFace (HF) T5 模型,以文本摘要为例进行演示。

该示例使用 Wikihow,为了简单起见,我们将展示在单节点 P4dn 实例(配备 8 个 A100 GPU)上的训练。我们很快将在 PyTorch Medium 频道上发布关于多节点集群上大规模 FSDP 训练的博客文章,敬请关注。

FSDP 是一个生产就绪的软件包,专注于易用性、性能和长期支持。FSDP 的主要优势之一是减少每个 GPU 上的内存占用。与 DDP 相比,这使得可以使用更少的总内存来训练更大的模型,并利用计算和通信的重叠来有效地训练模型。这种降低的内存压力可以用于训练更大的模型或增加批量大小,从而可能帮助提高整体训练吞吐量。你可以 在此处 阅读更多关于 PyTorch FSDP 的信息。

本教程中的 FSDP 功能

  • Transformer 自动包装策略

  • 混合精度

  • 在设备上初始化 FSDP 模型

  • 分片策略

  • 反向预取

  • 通过流式传输到 CPU 保存模型检查点

FSDP 工作原理回顾

在高层次上,FDSP 的工作原理如下

在构造函数中

  • 分片模型参数,每个 rank 只保留自己的分片

在前向传递中

  • 运行 all_gather 从所有 rank 收集所有分片,以恢复此 FSDP 单元的完整参数。运行前向计算

  • 丢弃刚刚收集的非自有参数分片,以释放内存

在反向传递中

  • 运行 all_gather 从所有 rank 收集所有分片,以恢复此 FSDP 单元中的完整参数。运行反向计算

  • 丢弃非自有参数以释放内存。

  • 运行 reduce_scatter 以同步梯度

微调 HF T5

HF T5 预训练模型有四种不同尺寸,从小型的 6000 万参数到 XXL 的 110 亿参数。在本教程中,我们将演示使用 FSDP 微调 T5 3B 以进行文本摘要,使用 WikiHow 数据集。本教程的主要重点是突出显示 FSDP 中可用于训练超过 30 亿参数的大规模模型的不同可用功能。此外,我们还介绍了基于 Transformer 模型的特定功能。本教程的代码可在 Pytorch 示例中找到。

设置

1.1 安装 PyTorch Nightlies

我们将安装 PyTorch Nightlies,因为某些功能(如激活检查点)在 Nightlies 中可用,并将添加到 1.12 之后的下一个 PyTorch 版本中。

pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html

1.2 数据集设置

请创建一个 data 文件夹,从 wikihowAll.csvwikihowSep.cs 下载 WikiHow 数据集,并将它们放在 data 文件夹中。我们将使用来自 summarization_dataset 的 wikihow 数据集。

接下来,我们将以下代码片段添加到 Python 脚本 “T5_training.py” 中。

注意

本教程的完整源代码可在 PyTorch 示例中找到。

1.3 导入必要的包

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl,
 apply_activation_checkpointing_wrapper)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 分布式训练设置。这里我们使用两个辅助函数来初始化分布式训练的进程,然后在训练完成后进行清理。在本教程中,我们将使用 torch elastic,使用 torchrun,它将自动设置 worker 的 RANKWORLD_SIZE

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

2.1 设置 HuggingFace T5 模型

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

我们还在此处添加了几个辅助函数,用于日期和格式化内存指标。

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 定义训练函数

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 定义验证函数

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 定义一个分布式训练函数,该函数将模型包装在 FSDP 中

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

2.5 解析参数并设置主函数

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)

要使用 torchrun 运行训练

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py

Transformer 包装策略

正如 之前的教程 中讨论的那样,auto_wrap_policy 是 FSDP 的功能之一,它可以轻松地自动分片给定的模型,并将模型、优化器和梯度分片放入不同的 FSDP 单元中。

对于某些架构(如 Transformer 编码器-解码器),模型的某些部分(如嵌入表)在编码器和解码器之间共享。在这种情况下,我们需要将嵌入表放在外部 FSDP 单元中,以便可以从编码器和解码器访问它。此外,通过注册 Transformer 的层类,可以使分片计划更具通信效率。在 PyTorch 1.12 中,FSDP 添加了此支持,现在我们有了 Transformer 的包装策略。

它可以按如下方式创建,其中 T5Block 表示 T5 Transformer 层类(包含 MHSA 和 FFN)。

t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
torch.cuda.set_device(local_rank)


model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)

要查看包装后的模型,你可以轻松地打印模型并直观地检查分片和 FSDP 单元。

混合精度

FSDP 支持灵活的混合精度训练,允许任意降低精度类型(如 fp16 或 bfloat16)。目前,BFloat16 仅在 Ampere GPU 上可用,因此你需要在使用前确认本机支持。例如,在 V100 上,BFloat16 仍然可以运行,但由于它以非本机方式运行,可能会导致明显的减速。

要检查是否本机支持 BFloat16,你可以使用以下方法

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

FSDP 中混合精度的优势之一是为参数、梯度和缓冲区提供对不同精度级别的精细控制,如下所示

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

请注意,如果未指定特定类型(参数、reduce、buffer),则根本不会进行类型转换。

这种灵活性允许用户进行细粒度控制,例如仅将梯度通信设置为以降低的精度进行,而所有参数/缓冲区计算都以全精度完成。这在节点内通信是主要瓶颈,并且参数/缓冲区必须保持全精度以避免精度问题的情况下可能很有用。可以使用以下策略完成此操作

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

在 2.4 中,我们只需将相关的混合精度策略添加到 FSDP 包装器中

model = FSDP(model,
       auto_wrap_policy=t5_auto_wrap_policy,
       mixed_precision=bfSixteen)

在我们的实验中,我们观察到使用 BFloat16 进行训练可加速高达 4 倍,并在某些实验中内存减少约 30%,可用于增加批量大小。

在设备上初始化 FSDP 模型

在 1.12 中,FSDP 支持 device_id 参数,该参数旨在在 device_id 给定的设备上初始化输入 CPU 模块。当整个模型不适合单个 GPU,但适合主机的 CPU 内存时,这很有用。当指定 device_id 时,FSDP 会按每个 FSDP 单元将模型移动到指定的设备,从而避免 GPU OOM 问题,同时初始化速度比基于 CPU 的初始化快几倍

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())

分片策略

FSDP 分片策略默认设置为完全分片模型参数,梯度和优化器状态在所有 rank 之间分片(也称为 Zero3 分片)。如果你有兴趣使用 Zero2 分片策略,其中仅分片优化器状态和梯度,则 FSDP 通过使用 “ShardingStrategy.SHARD_GRAD_OP” 而不是 “ShardingStrategy.FULL_SHARD” 将分片策略传递给 FSDP 初始化来支持此功能,如下所示

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

这将减少 FSDP 中的通信开销,在这种情况下,它在前向传递后和整个反向传递过程中保持完整参数。

这节省了反向传播期间的 all_gather,因此通信更少,但内存占用更高。请注意,完整模型参数在反向传播结束时释放,all_gather 将在下一个前向传播中发生。

反向预取

反向预取设置控制何时应请求下一个 FSDP 单元的参数。通过将其设置为 BACKWARD_PRE,可以在当前单元的计算开始之前更早地开始请求和到达下一个 FSDP 单元的参数。这重叠了 all_gather 通信和梯度计算,可以提高训练速度,但会略微增加内存消耗。它可以按如下方式在 2.4 中的 FSDP 包装器中使用

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch 有两种模式:BACKWARD_PREBACKWARD_POSTBACKWARD_POST 意味着在当前 FSDP 单元处理完成之前,不会请求下一个 FSDP 单元的参数,从而最大限度地减少内存开销。在某些情况下,使用 BACKWARD_PRE 可以将模型训练速度提高 2-10%,对于更大的模型,速度提升甚至更高。

模型检查点保存,通过流式传输到 Rank0 CPU

要使用 FULL_STATE_DICT 保存来保存模型检查点,这以与本地模型相同的方式保存模型,PyTorch 1.12 提供了一些实用程序来支持保存更大的模型。

首先,可以指定 FullStateDictConfig,允许仅在 rank 0 上填充 state_dict 并将其卸载到 CPU。

使用此配置时,FSDP 将 allgather 模型参数,一次一个地将它们卸载到 CPU,仅在 rank 0 上进行。当最终保存 state_dict 时,它将仅在 rank 0 上填充并包含 CPU 张量。这避免了模型大于单个 GPU 内存时可能发生的 OOM,并允许用户检查点模型,其大小大致为用户机器上可用的 CPU RAM。

此功能可以按如下方式运行

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
if rank == 0:
 save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
 torch.save(cpu_state, save_name)

总结

在本教程中,我们介绍了 Pytorch 1.12 中可用的 FSDP 的许多新功能,并使用 HF T5 作为运行示例。使用适当的包装策略(特别是对于 Transformer 模型),以及混合精度和反向预取,应该可以加快你的训练运行速度。此外,在设备上初始化模型以及通过流式传输到 CPU 保存检查点等功能应有助于避免处理大型模型时的 OOM 错误。

我们正在积极努力为下一个版本向 FSDP 添加新功能。如果你有反馈、功能请求、问题或在使用 FSDP 时遇到问题,请随时通过在 PyTorch Github 仓库 中打开 issue 与我们联系。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源