博客

在 TorchMultimodal 中使用 Pytorch Distributed 扩展多模态基础模型

简介

近年来,扩展模型规模已成为一个充满前景的研究领域。在自然语言处理(NLP)领域,语言模型已从数亿参数(BERT)发展到数千亿参数(GPT-3),并在下游任务中表现出显著的性能提升。大型语言模型的扩展定律 (Scaling laws) 也已在业界得到了广泛研究。在视觉领域也可以观察到类似的趋势,社区正在转向基于 Transformer 的模型(如 Vision TransformerMasked Auto Encoders)。显而易见,文本、图像、视频等单一模态都极大地受益于近期在模型规模上的进步,而各种框架也迅速调整以适应更大的模型。

与此同时,多模态在研究中的重要性日益凸显,图像-文本检索、视觉问答、视觉对话和文生图等任务在现实世界应用中正获得广泛关注。训练大规模多模态模型是自然而然的下一步,我们已经看到该领域涌现出多项成果,例如 OpenAI 的 CLIP、Google 的 Parti 以及 Meta 的 CM3

在本篇博客中,我们展示了一个案例研究,说明如何利用 PyTorch Distributed 的技术将 FLAVA 扩展到 10B 参数。FLAVA 是一个视觉和语言基础模型,可在 TorchMultimodal 中获取,该模型在单模态和多模态基准测试中均展现出了极具竞争力的性能。我们还在博客中提供了相关的代码指南。运行用于扩展 FLAVA 的示例脚本的说明可点击此处查看。

FLAVA 扩展概述

FLAVA 是一种多模态基础模型,由基于 Transformer 的图像和文本编码器,以及随后的基于 Transformer 的多模态融合模块组成。它在单模态和多模态数据上使用多种损失函数进行预训练。这包括掩码语言、掩码图像和多模态建模损失,这些损失要求模型根据上下文重构原始输入(自监督学习)。它还使用基于对齐的图像-文本对正负样本的图像文本匹配损失,以及 CLIP 风格的对比损失。除了多模态任务(如图像-文本检索)外,FLAVA 在单模态基准测试(NLP 的 GLUE 任务和视觉的图像分类)中也展示了具有竞争力的性能。

原始的 FLAVA 模型拥有约 3.5 亿(350M)参数,图像和文本编码器采用了 ViT-B16 配置(出自 Vision Transformer 论文)。多模态融合 Transformer 跟随在单模态编码器之后,但层数减半。我们探索了将每个编码器的大小增加到更大的 ViT 变体。

扩展的另一个方面是增加批量(batch size)的能力。FLAVA 使用批内负样本的对比损失,这通常受益于较大的批量(如此处所研究)。此外,当在 GPU 可用显存允许的最大批量附近运行时,通常也能实现最高的训练效率或吞吐量(另请参阅实验部分)。

下表展示了我们实验的不同模型配置。我们在实验部分确定了每种配置在内存中能够容纳的最大批量。

模型参数量近似值隐藏层大小 (Hidden size)MLP 大小注意力头数 (Heads)单模态层数多模态层数模型大小 (fp32)
350M (原始)7683072121261.33GB
900M102440961624123.48GB
1.8B128051201632166.66GB
2.7B1408614416402010.3GB
4.8B1664819216482418.1GB
10B20481024016644038GB

优化概述

PyTorch 提供了多种原生技术来高效扩展模型。在接下来的章节中,我们将介绍其中一些技术,并展示如何将它们应用于将 FLAVA 模型扩展至 100 亿(10B)参数。

分布式数据并行 (Distributed Data Parallel)

分布式训练的一个常见起点是数据并行。数据并行会在每个工作节点(GPU)上复制模型,并将数据集划分到各个工作节点上。不同的工作节点并行处理不同的数据分区,并在更新模型权重之前同步它们的梯度(通过 all-reduce)。下图展示了数据并行处理单个示例的流程(前向传播、反向传播和权重更新步骤)。

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

PyTorch 提供了一个原生 API DistributedDataParallel (DDP) 来实现数据并行,它可以作为模块包装器使用,如下所示。请参阅 PyTorch Distributed 文档了解更多详细信息。

from torchmultimodal.models.flava.model import flava_model_for_pretraining
import torch
import torch.distributed as dist

model = flava_model_for_pretraining().cuda()
# Initialize PyTorch Distributed process groups
# Please see https://pytorch.ac.cn/tutorials/intermediate/dist_tuto.html for details
dist.init_process_group(backend=”nccl”)
# Wrap model in DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

完全分片数据并行 (Fully Sharded Data Parallel)

训练应用程序的 GPU 内存使用量大致可以分解为:模型输入、中间激活值(梯度计算所需)、模型参数、梯度和优化器状态。扩展模型通常会增加这些元素中的每一个。使用 DDP 扩展模型时,如果单个 GPU 的内存不足,由于它在所有工作节点上复制参数、梯度和优化器状态,最终可能会导致内存溢出(OOM)问题。

为了减少这种复制并节省 GPU 内存,我们可以将模型参数、梯度和优化器状态分片到所有工作节点上,每个工作节点仅管理一个分片。这种技术由微软开发的 ZeRO-3 方法推广。PyTorch 原生实现了一种名为 FullyShardedDataParallel (FSDP) 的 API,该 API 于 PyTorch 1.12 中作为测试功能发布。在模块的前向和反向传播过程中,FSDP 会根据计算需要取消模型参数的分片(使用 all-gather),并在计算后重新分片。它使用 reduce-scatter 集合通信来同步梯度,以确保分片梯度在全局范围内取平均值。FSDP 包装模型的具体前后向传播流程如下所述。

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

要使用 FSDP,需要使用该 API 包装模型的子模块,以控制何时对特定子模块进行分片或取消分片。FSDP 提供了一个自动包装 API(请参见 auto_wrap_policy 参数),既可以直接使用,也提供了多种 包装策略,并支持编写您自己的策略

以下示例演示了如何用 FSDP 包装 FLAVA 模型。我们将自动包装策略指定为 transformer_auto_wrap_policy。这将把单个 Transformer 层(TransformerEncoderLayer)、图像 Transformer(ImageTransformer)、文本编码器(BERTTextEncoder)和多模态编码器(FLAVATransformerWithoutEmbeddings)包装为独立的 FSDP 单元。这采用了递归包装方法以实现高效的内存管理。例如,在单个 Transformer 层的前向或反向传播完成后,其参数会被丢弃,从而释放内存,减少峰值内存使用。

FSDP 还提供了许多可配置的选项来调整应用程序的性能。例如,在我们的用例中,我们说明了新的 limit_all_gathers 标志的使用,它防止过早地 all-gather 模型参数,从而减轻应用程序的内存压力。我们鼓励用户尝试使用此标志,这可能会提高内存占用较高的应用程序的性能。

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.models.flava.text_encoder import BertTextEncoder
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining().cuda()
dist.init_process_group(backend=”nccl”)

model = FSDP(
               model,
               device_id=torch.cuda.current_device(),
               auto_wrap_policy=partial(
                   transformer_auto_wrap_policy,
                   transformer_layer_cls={
                       TransformerEncoderLayer,
                       ImageTransformer,
                       BERTTextEncoder,
                       FLAVATransformerWithoutEmbeddings
                   },
               ),
               limit_all_gathers=True,
           )

激活检查点 (Activation Checkpointing)

如上所述,中间激活值、模型参数、梯度和优化器状态构成了整体 GPU 内存使用量。FSDP 可以减少后三者导致的内存消耗,但不会减少激活值所占用的内存。激活值占用的内存随着批量或隐藏层数量的增加而增加。激活检查点是一种通过在反向传播过程中重新计算激活值(而不是将它们保存在内存中)来减少内存占用的技术。例如,我们观察到在 2.7B 参数模型上应用激活检查点后,前向传播后的峰值活跃内存减少了约 4 倍。

PyTorch 提供了基于包装器的激活检查点 API。具体而言,checkpoint_wrapper 允许用户使用检查点包装单个模块,而 apply_activation_checkpointing 允许用户指定一种策略,在整体模块内包装特定模块。这两个 API 都可以应用于大多数模型,因为它们不需要修改模型定义代码。但是,如果需要对检查点片段进行更细粒度的控制(例如检查点模块内的特定函数),可以利用函数式 torch.utils.checkpoint API,但这需要修改模型代码。下面展示了将激活检查点包装器应用于单个 FLAVA Transformer 层(表示为 TransformerEncoderLayer)。有关激活检查点的详细描述,请参阅 PyTorch 文档中的描述。

from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining()
checkpoint_tformer_layers_policy = lambda submodule: isinstance(submodule, TransformerEncoderLayer)

apply_activation_checkpointing(
               model,
               checkpoint_wrapper_fn=checkpoint_wrapper,
               check_fn=checkpoint_tformer_layers_policy,
           )

综合使用上述方法:用激活检查点包装 FLAVA Transformer 层,并用 FSDP 包装整体模型,我们能够将 FLAVA 扩展到 10B 参数。

实验

我们针对上一节中不同优化技术对系统性能的影响进行了实证研究。对于所有实验,我们使用包含 8 个 A100 40GB GPU 的单个节点,并运行 1000 次迭代的预训练。所有运行均使用了 PyTorch 的 自动混合精度 (AMP) 以及 bfloat16 数据类型。同时启用了 TensorFloat32 格式以提高 A100 上的矩阵乘法性能。我们将吞吐量定义为每秒处理的平均项目数(文本或图像)(我们在计算吞吐量时忽略了前 100 次迭代,以考虑热身因素)。我们将训练至收敛及其对下游任务指标的影响留作未来研究的领域。

图 1 绘制了每种模型配置和优化技术的吞吐量,分别以 8 的局部批量和节点上可能的最大批量进行绘制。如果某个模型变体在某种优化下缺失数据点,则表示该模型无法在单节点上进行训练。

图 2 绘制了每种优化技术下每个工作节点可能的最大批量。我们观察到几点:

  1. 扩展模型大小:DDP 只能在节点上容纳 350M 和 900M 的模型。使用 FSDP,由于内存节省,我们能够训练比 DDP 大约 3 倍的模型(即 1.8B 和 2.7B 变体)。将激活检查点 (AC) 与 FSDP 结合,能够训练更大的模型,规模可达 DDP 的约 10 倍(即 4.8B 和 10B 变体)。
  2. 吞吐量
    • 对于较小的模型规模,在 8 的固定批量下,DDP 的吞吐量略高于或等于 FSDP,这可以通过 FSDP 所需的额外通信来解释。对于 FSDP 和 AC 的组合,吞吐量最低。这是因为 AC 在反向传播期间重新运行了检查点前向传播,以额外的计算代价换取了内存节省。然而,在 2.7B 模型的情况下,FSDP + AC 的吞吐量实际上比单纯使用 FSDP 更高。这是因为 2.7B 模型在 FSDP 下即使在批量为 8 时也接近内存极限,触发了 CUDA malloc 重试,这往往会减慢训练速度。AC 有助于减轻内存压力,从而避免了重试。
    • 对于 DDP 和 FSDP + AC,吞吐量随着每个模型批量的增加而增加。对于单纯的 FSDP,这对于较小的变体是成立的。然而,对于 1.8B 和 2.7B 参数的模型,我们观察到在增加批量时吞吐量出现了下降。如上所述,一个潜在的原因是:在内存极限时,PyTorch 的 CUDA 内存管理可能必须重试 cudaMalloc 调用和/或运行昂贵的碎片整理步骤,以查找空闲内存块来处理工作负载的内存需求,这可能导致训练减速。
    • 对于只能使用 FSDP 训练的大型模型(1.8B、2.7B、4.8B),实现最高吞吐量的设置是 FSDP + AC 扩展至最大批量。对于 10B,我们观察到较小批量和最大批量下的吞吐量几乎相等。这可能违背直觉,因为 AC 导致计算增加,并且最大化批量可能会因在 CUDA 内存极限下运行而导致昂贵的碎片整理操作。然而,对于这些大型模型,批量增加的幅度足以掩盖这种开销。

图 1:不同配置的训练吞吐量

  1. 批量大小:与 DDP 相比,单独使用 FSDP 可实现略高的批量。对于 350M 参数模型,使用 FSDP + AC 可实现约为 DDP 3 倍的批量;对于 900M 参数模型,可实现约 5.5 倍。即使对于 10B 模型,最大批量约为 20,这已经相当不错了。这实质上可以使用更少的 GPU 实现更大的全局批量,这对对比学习任务特别有用。

图 2:不同配置下可能的最大局部批量

结论

随着全球向多模态基础模型迈进,扩展模型参数和高效训练正成为关注的焦点。PyTorch 生态系统旨在通过为研究社区提供不同的工具(无论是用于训练还是扩展多模态模型)来加速该领域的创新。通过 FLAVA,我们提供了一个扩展多模态理解模型的示例。未来,我们计划增加对其他模型类型的支持,例如多模态生成模型,并展示它们的扩展因子。我们还希望实现许多此类扩展和内存节省技术(如分片和激活检查点)的自动化,以减少用户为达到预期规模和最大训练吞吐量所需的实验量。

参考资料