作者:Ankita De, Edward Wang (EcoF), Rohan Varma, Anjali Sridhar, Kartikay Khandelwal

引言

近年来,扩展模型规模已成为一个有前景的研究领域。在自然语言处理(NLP)领域,语言模型已从数亿参数(BERT)发展到数千亿参数(GPT-3),并在下游任务上表现出显著提升。业界也对大规模语言模型的扩展定律进行了广泛研究。在视觉领域也观察到类似的趋势,社区正转向基于 Transformer 的模型(如Vision TransformerMasked Auto Encoders)。很明显,文本、图像、视频等单一模态都从近期的规模扩展进展中获得了巨大收益,并且框架也迅速适应以支持更大的模型。

同时,多模态在研究中变得越来越重要,图像-文本检索、视觉问答、视觉对话和文本到图像生成等任务在现实世界应用中日益受到关注。训练大规模多模态模型是自然而然的下一步,我们已经看到该领域的许多努力,例如 OpenAI 的 CLIP、Google 的 Parti 和 Meta 的 CM3

在这篇博客中,我们展示了一个案例研究,演示了如何使用 PyTorch Distributed 中的技术将 FLAVA 扩展到 100 亿参数。FLAVA 是一个视觉和语言基础模型,可在 TorchMultimodal 中获取,它在单模态和多模态基准测试中都表现出竞争力。我们还在本博客中提供了相关的代码指针。运行扩展 FLAVA 示例脚本的说明可在此处找到:此处

扩展 FLAVA 概览

FLAVA 是一个基础多模态模型,它由基于 Transformer 的图像编码器和文本编码器组成,然后是一个基于 Transformer 的多模态融合模块。它使用多样化的损失在单模态和多模态数据上进行了预训练。这包括掩码语言、图像和多模态建模损失,这些损失要求模型从其上下文(自监督学习)中重建原始输入。它还使用图像文本匹配损失来处理对齐的图像-文本对的正负样本,以及 CLIP 风格的对比损失。除了多模态任务(如图像-文本检索)外,FLAVA 在单模态基准测试(NLP 的 GLUE 任务和视觉的图像分类)上也表现出竞争力。

原始的 FLAVA 模型拥有约 3.5 亿参数,并使用 ViT-B16 配置(来自Vision Transformer 论文)作为图像和文本编码器。多模态融合 Transformer 跟随单模态编码器,但层数减半。我们探索将每个编码器的大小增加到更大的 ViT 变体。

扩展的另一个方面是增加批处理大小的能力。FLAVA 使用了基于批内负样本的对比损失,这通常受益于大批处理大小(如此处研究所示)。训练效率或吞吐量通常在接近 GPU 可用内存所决定的最大可能批处理大小时达到最高(另请参见实验部分)。

下表显示了我们实验的不同模型配置。我们还在实验部分确定了每种配置在内存中能够容纳的最大批处理大小。

大致模型参数 隐藏层大小 MLP 大小 注意力头数 单模态层数 多模态层数 模型大小 (fp32)
3.5亿 (原始) 768 3072 12 12 6 1.33GB
9亿 1024 4096 16 24 12 3.48GB
18亿 1280 5120 16 32 16 6.66GB
27亿 1408 6144 16 40 20 10.3GB
48亿 1664 8192 16 48 24 18.1GB
100亿 2048 10240 16 64 40 38GB

优化概览

PyTorch 提供了多种原生技术来高效地扩展模型。在接下来的部分中,我们将介绍其中一些技术,并展示如何将它们应用于将 FLAVA 模型扩展到 100 亿参数。

分布式数据并行 (DDP)

分布式训练的一个常见起点是数据并行。数据并行将模型复制到每个工作进程(GPU),并在工作进程之间划分数据集。不同的工作进程并行处理不同的数据分区,并在更新模型权重之前同步它们的梯度(通过 AllReduce)。下图展示了数据并行处理单个示例的流程(前向传播、反向传播和权重更新步骤)

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

PyTorch 提供了一个原生 API,DistributedDataParallel (DDP) 来实现数据并行,它可以用作模块包装器,如下所示。更多详情请参阅 PyTorch Distributed 文档

from torchmultimodal.models.flava.model import flava_model_for_pretraining
import torch
import torch.distributed as dist

model = flava_model_for_pretraining().cuda()
# Initialize PyTorch Distributed process groups
# Please see https://pytorch.ac.cn/tutorials/intermediate/dist_tuto.html for details
dist.init_process_group(backend=”nccl”)
# Wrap model in DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

完全分片数据并行 (FSDP)

训练应用程序的 GPU 内存使用量大致可以分解为模型输入、中间激活(梯度计算所需)、模型参数、梯度和优化器状态。扩展模型通常会增加这些元素。使用 DDP 扩展模型最终可能导致内存不足问题,因为单个 GPU 的内存不足以在所有工作进程上复制参数、梯度和优化器状态。

为了减少这种复制并节省 GPU 内存,我们可以将模型参数、梯度和优化器状态分片到所有工作进程中,每个工作进程仅管理一个分片。这项技术由微软开发的 ZeRO-3 方法推广开来。PyTorch 提供了这种方法的原生实现,即 FullyShardedDataParallel (FSDP) API,它在 PyTorch 1.12 中作为 Beta 功能发布。在模块的前向和反向传播过程中,FSDP 根据计算需要解除模型参数的分片(使用 All-Gather),并在计算后重新分片。它使用 Reduce-Scatter 集合操作同步梯度,以确保分片梯度得到全局平均。下图详细介绍了用 FSDP 包装的模型的前向和反向传播流程。

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

要使用 FSDP,需要使用该 API 包装模型的子模块,以控制何时对特定子模块进行分片或解除分片。FSDP 提供了一个可直接使用的自动包装 API(参见 auto_wrap_policy 参数),以及多种包装策略编写自定义策略的能力。

以下示例演示了如何使用 FSDP 包装 FLAVA 模型。我们将自动包装策略指定为 transformer_auto_wrap_policy。这将把单个 Transformer 层(TransformerEncoderLayer)、图像 Transformer(ImageTransformer)、文本编码器(BERTTextEncoder)和多模态编码器(FLAVATransformerWithoutEmbeddings)包装成独立的 FSDP 单元。这使用了一种递归包装方法来实现高效的内存管理。例如,在单个 Transformer 层的前向或反向传播完成后,其参数会被释放,从而释放内存并降低峰值内存使用量。

FSDP 还提供了许多可配置选项来调整应用程序的性能。例如,在我们的用例中,我们说明了如何使用新的 limit_all_gathers 标志,它防止模型参数过早地进行 All-Gather 操作,从而减轻应用程序的内存压力。我们鼓励用户尝试使用此标志,它可以潜在地提高具有高活跃内存使用量的应用程序的性能。

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.models.flava.text_encoder import BertTextEncoder
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining().cuda()
dist.init_process_group(backend=”nccl”)

model = FSDP(
               model,
               device_id=torch.cuda.current_device(),
               auto_wrap_policy=partial(
                   transformer_auto_wrap_policy,
                   transformer_layer_cls={
                       TransformerEncoderLayer,
                       ImageTransformer,
                       BERTTextEncoder,
                       FLAVATransformerWithoutEmbeddings
                   },
               ),
               limit_all_gathers=True,
           )

激活检查点 (Activation Checkpointing)

如上所述,中间激活、模型参数、梯度和优化器状态都会对整体 GPU 内存使用量产生影响。FSDP 可以减少后三者的内存消耗,但不能减少激活消耗的内存。激活使用的内存随着批处理大小或隐藏层数量的增加而增加。激活检查点是一项技术,它通过在反向传播期间重新计算激活而不是将其保存在内存中来减少特定检查点模块的内存使用量。例如,我们将激活检查点应用于 27 亿参数模型后,观察到前向传播后峰值活跃内存减少了约 4 倍。

PyTorch 提供了一个基于包装器的激活检查点 API。特别是,checkpoint_wrapper 允许用户对单个模块进行检查点包装,而 apply_activation_checkpointing 允许用户指定一个策略来对整体模块内的子模块进行检查点包装。这两个 API 都可以应用于大多数模型,因为它们不需要修改模型定义代码。但是,如果需要对检查点段进行更精细的控制,例如对模块内的特定函数进行检查点,则可以利用函数式 torch.utils.checkpoint API,但这需要修改模型代码。将激活检查点包装器应用于单个 FLAVA Transformer 层(由 TransformerEncoderLayer 表示)的示例如下所示。有关激活检查点的详细说明,请参阅 PyTorch 文档中的描述。

from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining()
checkpoint_tformer_layers_policy = lambda submodule: isinstance(submodule, TransformerEncoderLayer)

apply_activation_checkpointing(
               model,
               checkpoint_wrapper_fn=checkpoint_wrapper,
               check_fn=checkpoint_tformer_layers_policy,
           )

结合使用,如上所示,通过对 FLAVA Transformer 层应用激活检查点并对整个模型应用 FSDP 包装,我们能够将 FLAVA 扩展到 100 亿参数。

实验

我们对上一节中不同优化方法对系统性能的影响进行了实证研究。在我们所有的实验中,我们使用一个带有 8 块 A100 40GB GPU 的节点,并运行 1000 次迭代的预训练。所有运行都使用了 PyTorch 的 自动混合精度 和 bfloat16 数据类型。TensorFloat32 格式也已启用,以提高 A100 上的 matmul(矩阵乘法)性能。我们将吞吐量定义为每秒处理的项目(文本或图像)的平均数量(我们在测量吞吐量时忽略了前 100 次迭代,以排除预热时间)。我们将训练收敛及其对下游任务指标的影响留待未来研究。

图 1 绘制了每种模型配置和优化方法的吞吐量,包括本地批处理大小为 8 的情况,以及在单个节点上可能达到的最大批处理大小的情况。某种优化方法在某个模型变体上没有数据点表示该模型无法在单个节点上进行训练。

图 2 绘制了每种优化方法下每个工作进程可能达到的最大批处理大小。我们观察到以下几点:

  1. 模型规模扩展:DDP 只能在单个节点上容纳 3.5 亿和 9 亿参数的模型。使用 FSDP,由于内存节省,我们能够训练比 DDP 大约 3 倍的模型(即 18 亿和 27 亿参数的变体)。将激活检查点 (AC) 与 FSDP 结合使用,能够训练更大的模型,相比 DDP 规模大约提高 10 倍(即 48 亿和 100 亿参数的变体)。
  2. 吞吐量
    • 对于较小的模型,在固定批处理大小为 8 时,DDP 的吞吐量略高于或等于 FSDP,这可以用 FSDP 所需的额外通信来解释。FSDP 与 AC 结合时吞吐量最低。这是因为 AC 在反向传播期间重新运行检查点的前向传播,以额外的计算换取内存节省。然而,在 27 亿参数模型的情况下,FSDP + AC 的吞吐量实际上高于单独使用 FSDP。这是因为即使在批处理大小为 8 时,使用 FSDP 的 27 亿模型也接近内存限制,从而触发 CUDA malloc 重试,这往往会减慢训练速度。AC 有助于减轻内存压力并避免重试。
    • 对于 DDP 和 FSDP + AC,每种模型的吞吐量随批处理大小的增加而增加。对于单独使用 FSDP 的情况,对于较小的变体也是如此。然而,对于 18 亿和 27 亿参数的模型,我们观察到在增加批处理大小时吞吐量会下降。一个潜在的原因,如上所述,是在内存限制下,PyTorch 的 CUDA 内存管理可能不得不重试 cudaMalloc 调用和/或运行昂贵的碎片整理步骤以找到空闲内存块来处理工作负载的内存需求,这可能导致训练速度变慢。
    • 对于只能使用 FSDP 训练的较大模型(18 亿、27 亿、48 亿),实现最高吞吐量的设置是 FSDP + AC 并扩展到最大批处理大小。对于 100 亿参数的模型,我们观察到在较小批处理大小和最大批处理大小下的吞吐量几乎相等。这可能与直觉相反,因为 AC 会增加计算量,并且达到最大批处理大小可能由于在 CUDA 内存限制下运行而导致昂贵的碎片整理操作。然而,对于这些大型模型,批处理大小的增加足以掩盖这种开销。

图 1:不同配置下的训练吞吐量

  1. 批处理大小:单独使用 FSDP 可以实现比 DDP 略大的批处理大小。对于 3.5 亿参数模型,使用 FSDP + AC 可以实现比 DDP 大约 3 倍的批处理大小;对于 9 亿参数模型,可以实现大约 5.5 倍的批处理大小。即使对于 100 亿参数模型,最大批处理大小约为 20,这相当不错。这实质上允许使用更少的 GPU 实现更大的全局批处理大小,这对于对比学习任务特别有用。

图 2:不同配置下可能达到的最大本地批处理大小

结论

随着世界迈向多模态基础模型,模型参数的扩展和高效训练正成为关注的焦点。PyTorch 生态系统旨在通过为研究社区提供用于训练和扩展多模态模型的不同工具来加速该领域的创新。通过 FLAVA,我们展示了一个用于多模态理解的模型扩展示例。未来,我们计划增加对其他类型模型(如多模态生成模型)的支持,并展示它们的扩展因素。我们还希望自动化许多这些扩展和内存节省技术(例如分片和激活检查点),以减少用户为达到所需规模和最大训练吞吐量所需的实验量。

参考文献