跳转到主要内容
博客

在 TorchMultimodal 中使用 Pytorch Distributed 扩展多模态基础模型

简介

近年来,扩展模型规模已成为一个前景广阔的研究领域。在自然语言处理(NLP)领域,语言模型已从数亿参数(BERT)发展到数千亿参数(GPT-3),在下游任务上表现出显著的改进。业界也对大规模语言模型的扩展定律进行了广泛研究。在视觉领域也观察到了类似的趋势,社区正转向基于 Transformer 的模型(如 Vision TransformerMasked Auto Encoders)。很明显,单一模态——文本、图像、视频——已从近期的规模化发展中受益匪浅,各种框架也迅速适应以支持更大的模型。

与此同时,多模态在研究中的重要性日益增加,像图文检索、视觉问答、视觉对话和文本到图像生成等任务在现实世界应用中越来越受欢迎。训练大规模多模态模型是自然的下一步,我们已经看到该领域的一些努力,如 OpenAI 的 CLIP、Google 的 Parti 和 Meta 的 CM3

在这篇博客中,我们展示了一个案例研究,演示了如何使用 PyTorch Distributed 的技术将 FLAVA 扩展到 100 亿参数。FLAVA 是一个视觉和语言基础模型,可在 TorchMultimodal 中获取,它在单模态和多模态基准测试中都表现出有竞争力的性能。我们还在本博客中给出了相关的代码指针。运行扩展 FLAVA 的示例脚本的说明可以在这里找到。

扩展 FLAVA 概述

FLAVA 是一个基础多模态模型,由基于 Transformer 的图像和文本编码器,以及一个基于 Transformer 的多模态融合模块组成。它在单模态和多模态数据上使用多种损失函数进行预训练。这包括掩码语言、图像和多模态建模损失,这些损失要求模型从其上下文中重建原始输入(自监督学习)。它还使用基于正负对齐图文对的图文匹配损失,以及 CLIP 风格的对比损失。除了多模态任务(如图文检索)外,FLAVA 在单模态基准测试中也表现出有竞争力的性能(NLP 的 GLUE 任务和视觉的图像分类)。

原始的 FLAVA 模型约有 3.5 亿参数,其图像和文本编码器使用了(来自 Vision Transformer 论文的)ViT-B16 配置。多模态融合 Transformer 紧随单模态编码器之后,但层数是其一半。我们探索将每个编码器的规模增加到更大的 ViT 变体。

扩展的另一个方面是增加批量大小(batch size)的能力。FLAVA 利用了基于批内负样本的对比损失,这通常受益于大的批量大小(如这里所研究的)。最高的训练效率或吞吐量通常在接近 GPU 内存允许的最大可能批量大小时实现(另见实验部分)。

下表展示了我们实验过的不同模型配置。我们还在实验部分确定了每种配置能够容纳在内存中的最大批量大小。

近似模型参数隐藏层大小MLP 大小注意力头数单模态层数多模态层数模型大小 (fp32)
3.5亿 (原始)7683072121261.33GB
9亿102440961624123.48GB
18亿128051201632166.66GB
27亿1408614416402010.3GB
48亿1664819216482418.1GB
100亿20481024016644038GB

优化概述

PyTorch 提供了几种原生技术来高效地扩展模型。在以下部分,我们将介绍其中一些技术,并展示如何应用它们将 FLAVA 模型扩展到 100 亿个参数。

分布式数据并行 (Distributed Data Parallel)

分布式训练的一个常见起点是数据并行。数据并行在每个工作节点(GPU)上复制模型,并将数据集在这些工作节点之间进行分区。不同的工作节点并行处理不同的数据分区,并在模型权重更新前同步它们的梯度(通过 all-reduce)。下图展示了数据并行处理单个样本的流程(前向、后向和权重更新步骤)。

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

PyTorch 提供了一个原生 API,即 DistributedDataParallel (DDP),用于实现数据并行。它可以作为一个模块包装器使用,如下所示。更多详情请参阅 PyTorch Distributed 的文档

from torchmultimodal.models.flava.model import flava_model_for_pretraining
import torch
import torch.distributed as dist

model = flava_model_for_pretraining().cuda()
# Initialize PyTorch Distributed process groups
# Please see https://pytorch.ac.cn/tutorials/intermediate/dist_tuto.html for details
dist.init_process_group(backend=”nccl”)
# Wrap model in DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

完全分片数据并行 (Fully Sharded Data Parallel)

训练应用的 GPU 内存使用量大致可分为模型输入、中间激活(用于梯度计算)、模型参数、梯度和优化器状态。扩展模型通常会增加这些元素的每一个。使用 DDP 扩展模型最终可能导致内存不足问题,因为它在所有工作节点上复制参数、梯度和优化器状态,单个 GPU 的内存会变得不足。

为了减少这种复制并节省 GPU 内存,我们可以将模型参数、梯度和优化器状态在所有工作节点上进行分片,每个工作节点只管理一个分片。这项技术由微软开发的 ZeRO-3 方法推广。PyTorch 对此方法的原生实现作为 FullyShardedDataParallel (FSDP) API 提供,在 PyTorch 1.12 中作为 beta 功能发布。在模块的前向和后向传播过程中,FSDP 会根据计算需要取消分片模型参数(使用 all-gather),并在计算后重新分片。它使用 reduce-scatter 集合操作来同步梯度,以确保分片后的梯度是全局平均的。FSDP 包装的模型的前向和后向传播流程详述如下。

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

要使用 FSDP,模型的子模块需要用该 API 进行包装,以控制特定子模块何时进行分片或取消分片。FSDP 提供了一个自动包装 API(参见 auto_wrap_policy 参数),可以开箱即用,还提供了几种包装策略以及编写自定义策略的能力。

以下示例演示了如何用 FSDP 包装 FLAVA 模型。我们将自动包装策略指定为 `transformer_auto_wrap_policy`。这将把单个 Transformer 层(`TransformerEncoderLayer`)、图像 Transformer(`ImageTransformer`)、文本编码器(`BERTTextEncoder`)和多模态编码器(`FLAVATransformerWithoutEmbeddings`)包装为独立的 FSDP 单元。这使用了一种递归包装方法以实现高效的内存管理。例如,在单个 Transformer 层的前向或后向传播完成后,其参数会被丢弃,从而释放内存,减少峰值内存使用。

FSDP 还提供了许多可配置的选项来调整应用程序的性能。例如,在我们的用例中,我们演示了新的 `limit_all_gathers` 标志的使用,该标志可以防止过早地对模型参数进行 all-gather,从而减轻应用程序的内存压力。我们鼓励用户尝试这个标志,它有可能提高具有高活动内存使用率的应用程序的性能。

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.models.flava.text_encoder import BertTextEncoder
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining().cuda()
dist.init_process_group(backend=”nccl”)

model = FSDP(
               model,
               device_id=torch.cuda.current_device(),
               auto_wrap_policy=partial(
                   transformer_auto_wrap_policy,
                   transformer_layer_cls={
                       TransformerEncoderLayer,
                       ImageTransformer,
                       BERTTextEncoder,
                       FLAVATransformerWithoutEmbeddings
                   },
               ),
               limit_all_gathers=True,
           )

激活检查点 (Activation Checkpointing)

如上所述,中间激活、模型参数、梯度和优化器状态共同构成了 GPU 的总内存使用量。FSDP 可以减少后三者导致的内存消耗,但不能减少激活所消耗的内存。激活所用的内存会随着批量大小或隐藏层数量的增加而增加。激活检查点是一种减少这种内存使用的技术,它通过在后向传播期间重新计算激活,而不是将它们保存在内存中,来处理特定的检查点模块。例如,我们观察到,通过对 27 亿参数模型应用激活检查点,前向传播后的峰值活动内存减少了约 4 倍。

PyTorch 提供了基于包装器的激活检查点 API。特别是,`checkpoint_wrapper` 允许用户用检查点包装单个模块,而 `apply_activation_checkpointing` 允许用户指定一个策略,用该策略来包装整个模块内的子模块。这两个 API 都可以应用于大多数模型,因为它们不需要对模型定义代码进行任何修改。但是,如果需要对检查点段进行更精细的控制,例如检查点模块内的特定函数,则可以利用函数式的 `torch.utils.checkpoint` API,尽管这需要修改模型代码。下面展示了将激活检查点包装器应用于单个 FLAVA Transformer 层(由 `TransformerEncoderLayer` 表示)的应用。有关激活检查点的详细描述,请参阅 PyTorch 文档中的描述。

from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining()
checkpoint_tformer_layers_policy = lambda submodule: isinstance(submodule, TransformerEncoderLayer)

apply_activation_checkpointing(
               model,
               checkpoint_wrapper_fn=checkpoint_wrapper,
               check_fn=checkpoint_tformer_layers_policy,
           )

通过将 FLAVA Transformer 层与激活检查点包装起来,并如上所示将整个模型与 FSDP 包装起来,我们能够将 FLAVA 扩展到 100 亿参数。

实验

我们对上一节中不同优化对系统性能的影响进行了实证研究。对于我们所有的实验,我们使用一个配备 8 个 A100 40GB GPU 的单节点,并进行 1000 次迭代的预训练。所有运行都使用了 PyTorch 的自动混合精度,数据类型为 bfloat16。还启用了 TensorFloat32 格式以提高 A100 上的矩阵乘法性能。我们将吞吐量定义为每秒处理的平均项目数(文本或图像)(在测量吞吐量时,我们忽略了前 100 次迭代以考虑预热)。我们将训练至收敛及其对下游任务指标的影响作为未来研究的领域。

图 1 绘制了每种模型配置和优化的吞吐量,分别使用本地批量大小为 8 和单个节点上可能的最大批量大小。如果某个模型变体在某个优化下没有数据点,则表示该模型无法在单个节点上进行训练。

图 2 绘制了每种优化下每个工作节点可能的最大批量大小。我们观察到几点:

  1. 扩展模型规模:DDP 只能在一个节点上容纳 3.5 亿和 9 亿参数的模型。使用 FSDP,由于节省了内存,我们能够训练比 DDP 大约 3 倍的模型(即 18 亿和 27 亿的变体)。将激活检查点(AC)与 FSDP 结合使用,可以训练更大的模型,比 DDP 大约 10 倍(即 48 亿和 100 亿的变体)。
  2. 吞吐量
    • 对于较小的模型尺寸,在固定的批量大小为 8 的情况下,DDP 的吞吐量略高于或等于 FSDP,这可以由 FSDP 所需的额外通信来解释。对于 FSDP 和 AC 组合的情况,吞吐量最低。这是因为 AC 在后向传播期间重新运行了检查点的前向传播,用额外的计算换取了内存节省。然而,在 27 亿参数模型的情况下,FSDP + AC 的吞吐量实际上高于单独使用 FSDP。这是因为即使在批量大小为 8 的情况下,使用 FSDP 的 27 亿模型也已接近内存极限,触发了 CUDA malloc 重试,这会减慢训练速度。AC 有助于减轻内存压力,并避免了重试。
    • 对于 DDP 和 FSDP + AC,每个模型的吞吐量都随着批量大小的增加而增加。对于单独使用 FSDP,这对于较小的变体是成立的。然而,对于 18 亿和 27 亿参数的模型,我们观察到在增加批量大小时吞吐量会下降。如上所述,一个可能的原因是,在内存极限下,PyTorch 的 CUDA 内存管理可能需要重试 cudaMalloc 调用和/或运行昂贵的碎片整理步骤来寻找空闲的内存块以满足工作负载的内存需求,这可能导致训练减速。
    • 对于只能用 FSDP 训练的较大模型(18亿、27亿、48亿),吞吐量最高的设置是使用 FSDP + AC 并扩展到最大批量大小。对于 100 亿参数的模型,我们观察到较小和最大批量大小的吞吐量几乎相等。这可能与直觉相悖,因为 AC 导致计算量增加,而将批量大小最大化可能会因在 CUDA 内存极限下操作而导致昂贵的碎片整理操作。然而,对于这些大型模型,批量大小的增加足以掩盖这种开销。

图1:不同配置下的训练吞吐量

  1. 批量大小:单独使用 FSDP 相比 DDP 可以实现稍大的批量大小。使用 FSDP + AC 对于 3.5 亿参数模型可以实现比 DDP 大约 3 倍的批量大小,对于 9 亿参数模型则可以实现大约 5.5 倍。即使对于 100 亿参数的模型,最大批量大小也能达到约 20,这相当可观。这实质上使得可以用更少的 GPU 实现更大的全局批量大小,这对于对比学习任务尤其有用。

图2:不同配置下可能的最大本地批量大小

结论

随着世界向多模态基础模型发展,扩展模型参数和高效训练正成为一个焦点领域。PyTorch 生态系统旨在通过向研究社区提供不同的工具来加速该领域的创新,这些工具既用于训练也用于扩展多模态模型。通过 FLAVA,我们展示了一个扩展多模态理解模型的例子。未来,我们计划增加对其他类型模型的支持,如用于多模态生成的模型,并展示它们的扩展因子。我们还希望自动化许多这些扩展和内存节省技术(如分片和激活检查点),以减少用户为达到所需规模和最大训练吞吐量所需的实验量。

参考文献