跳转到主要内容
博客

在 TorchMultimodal 中使用 Pytorch Distributed 扩展多模态基础模型

引言

近年来,扩展模型规模已成为一个很有前景的研究领域。在自然语言处理领域,语言模型已从数亿参数(BERT)扩展到数千亿参数(GPT-3),并在下游任务中表现出显著的改进。大型语言模型的扩展定律在业界也得到了广泛研究。在视觉领域也观察到类似的趋势,社区正转向基于 Transformer 的模型(如Vision TransformerMasked Auto Encoders)。很明显,文本、图像、视频等单一模态都从最近的规模进步中受益匪浅,框架也迅速适应以容纳更大的模型。

与此同时,多模态在研究中变得越来越重要,图像-文本检索、视觉问答、视觉对话和文本到图像生成等任务在实际应用中越来越受欢迎。训练大规模多模态模型是自然的下一步,我们已经看到该领域的一些努力,例如 OpenAI 的 CLIP、Google 的 Parti 和 Meta 的 CM3

在这篇博客中,我们展示了一个案例研究,演示了如何使用 PyTorch Distributed 中的技术将 FLAVA 扩展到 100 亿个参数。FLAVA 是一种视觉和语言基础模型,可在 TorchMultimodal 中使用,在单模态和多模态基准测试中均表现出有竞争力的性能。我们还在本博客中提供了相关的代码指针。运行用于扩展 FLAVA 的示例脚本的说明可以在此处找到。

FLAVA 扩展概述

FLAVA 是一种基础多模态模型,由基于 Transformer 的图像和文本编码器组成,然后是一个基于 Transformer 的多模态融合模块。它使用各种损失在单模态和多模态数据上进行预训练。这包括掩码语言、图像和多模态建模损失,这些损失要求模型从其上下文(自监督学习)重建原始输入。它还使用图像文本匹配损失对对齐图像-文本对的正例和负例以及 CLIP 风格的对比损失进行训练。除了多模态任务(如图像-文本检索)之外,FLAVA 在单模态基准(NLP 的 GLUE 任务和视觉的图像分类)上也表现出有竞争力的性能。

原始 FLAVA 模型有大约 3.5 亿个参数,并使用 ViT-B16 配置(来自 Vision Transformer 论文)用于图像和文本编码器。多模态融合 Transformer 遵循单模态编码器,但层数减半。我们探索将每个编码器的大小增加到更大的 ViT 变体。

扩展的另一个方面是增加批次大小的能力。FLAVA 利用批次内负样本上的对比损失,这通常受益于大批次大小(如此处研究)。当在 GPU 内存允许的最大批次大小附近操作时,通常也能实现最大的训练效率或吞吐量(另请参见实验部分)。

下表显示了我们实验的不同模型配置。我们还在实验部分确定了每种配置能够适应内存的最大批次大小。

近似模型参数隐藏大小MLP 大小头数单模态层数多模态层数模型大小 (fp32)
3.5 亿(原始)7683072121261.33 GB
9 亿102440961624123.48 GB
18 亿128051201632166.66 GB
27 亿1408614416402010.3 GB
48 亿1664819216482418.1 GB
100 亿20481024016644038 GB

优化概述

PyTorch 提供了多种原生技术来有效地扩展模型。在接下来的章节中,我们将介绍其中一些技术,并展示如何将它们应用于将 FLAVA 模型扩展到 100 亿个参数。

分布式数据并行

分布式训练的常见起点是数据并行。数据并行将模型复制到每个 worker(GPU),并将数据集分区到 worker。不同的 worker 并行处理不同的数据分区,并在模型权重更新之前同步其梯度(通过 all reduce)。下图展示了数据并行处理单个示例的流程(前向、反向和权重更新步骤)

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

PyTorch 提供了一个原生 API,DistributedDataParallel (DDP) 来启用数据并行,它可以作为模块包装器使用,如下所示。有关更多详细信息,请参阅 PyTorch 分布式文档

from torchmultimodal.models.flava.model import flava_model_for_pretraining
import torch
import torch.distributed as dist

model = flava_model_for_pretraining().cuda()
# Initialize PyTorch Distributed process groups
# Please see https://pytorch.ac.cn/tutorials/intermediate/dist_tuto.html for details
dist.init_process_group(backend=”nccl”)
# Wrap model in DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

完全分片数据并行

训练应用程序的 GPU 内存使用量大致可分为模型输入、中间激活(用于梯度计算)、模型参数、梯度和优化器状态。扩展模型通常会增加这些元素的每一个。使用 DDP 扩展模型最终可能导致内存不足问题,因为单个 GPU 的内存不足以容纳所有参数、梯度和优化器状态,因为它会在所有 worker 上复制这些内容。

为了减少这种复制并节省 GPU 内存,我们可以将模型参数、梯度和优化器状态分片到所有 worker 上,每个 worker 只管理一个分片。这种技术由微软开发的 ZeRO-3 方法推广开来。PyTorch 1.12 发布了这种方法的 PyTorch 原生实现,作为 FullyShardedDataParallel (FSDP) API 的 Beta 功能。在模块的前向和反向传播过程中,FSDP 会根据计算需要取消分片模型参数(使用 all-gather),并在计算后重新分片它们。它使用 reduce-scatter 集合操作同步梯度,以确保分片梯度在全球范围内平均。下面详细介绍了使用 FSDP 封装的模型的前向和反向传播流程

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/

要使用 FSDP,模型的子模块需要用 API 包装,以控制何时分片或取消分片特定子模块。FSDP 提供了一个可以开箱即用的自动包装 API(请参阅 auto_wrap_policy 参数)以及多种包装策略编写自己的策略的能力。

以下示例演示了使用 FSDP 封装 FLAVA 模型。我们将自动封装策略指定为 transformer_auto_wrap_policy。这将把单个 Transformer 层 (TransformerEncoderLayer)、图像 Transformer (ImageTransformer)、文本编码器 (BERTTextEncoder) 和多模态编码器 (FLAVATransformerWithoutEmbeddings) 封装为独立的 FSDP 单元。这使用了递归封装方法以实现高效内存管理。例如,在单个 Transformer 层的前向或反向传播完成后,其参数将被丢弃,从而释放内存,从而降低峰值内存使用量。

FSDP 还提供了许多可配置选项来调整应用程序的性能。例如,在我们的用例中,我们演示了使用新的 limit_all_gathers 标志,它可以防止过早地收集所有模型参数,从而减轻应用程序的内存压力。我们鼓励用户尝试使用此标志,它可能会提高具有高活动内存使用量的应用程序的性能。

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.models.flava.text_encoder import BertTextEncoder
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining().cuda()
dist.init_process_group(backend=”nccl”)

model = FSDP(
               model,
               device_id=torch.cuda.current_device(),
               auto_wrap_policy=partial(
                   transformer_auto_wrap_policy,
                   transformer_layer_cls={
                       TransformerEncoderLayer,
                       ImageTransformer,
                       BERTTextEncoder,
                       FLAVATransformerWithoutEmbeddings
                   },
               ),
               limit_all_gathers=True,
           )

激活检查点

如上所述,中间激活、模型参数、梯度和优化器状态共同构成了整体 GPU 内存使用量。FSDP 可以减少后三者造成的内存消耗,但不能减少激活消耗的内存。激活使用的内存随着批次大小或隐藏层数量的增加而增加。激活检查点是一种通过在反向传播期间重新计算激活而不是将它们保留在特定检查点模块的内存中来减少内存使用量的技术。例如,我们观察到在将激活检查点应用于 2.7B 参数模型后,前向传播后的峰值活动内存减少了约 4 倍。

PyTorch 提供了一个基于包装器的激活检查点 API。特别是,checkpoint_wrapper 允许用户使用检查点包装单个模块,而 apply_activation_checkpointing 允许用户指定一个策略来包装整体模块中的模块。这两个 API 都可以应用于大多数模型,因为它们不需要对模型定义代码进行任何修改。但是,如果需要对检查点片段进行更精细的控制,例如检查点模块中的特定函数,则可以使用函数式 torch.utils.checkpoint API,尽管这需要修改模型代码。下面显示了将激活检查点包装器应用于单个 FLAVA Transformer 层(由 TransformerEncoderLayer 表示)。有关激活检查点的详细描述,请参阅 PyTorch 文档中的描述。

from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining()
checkpoint_tformer_layers_policy = lambda submodule: isinstance(submodule, TransformerEncoderLayer)

apply_activation_checkpointing(
               model,
               checkpoint_wrapper_fn=checkpoint_wrapper,
               check_fn=checkpoint_tformer_layers_policy,
           )

结合使用,如上所示,使用激活检查点封装 FLAVA Transformer 层,并使用 FSDP 封装整个模型,我们能够将 FLAVA 扩展到 100 亿个参数。

实验

我们对上一节中不同优化对系统性能的影响进行了实证研究。对于我们所有的实验,我们使用一个带有 8 个 A100 40GB GPU 的单节点,并进行 1000 次迭代的预训练。所有运行都使用 PyTorch 的 自动混合精度 和 bfloat16 数据类型。TensorFloat32 格式也已启用,以提高 A100 上的矩阵乘法性能。我们将吞吐量定义为每秒处理的平均项目数(文本或图像)(我们在测量吞吐量时忽略前 100 次迭代以考虑预热)。我们将在未来研究训练收敛及其对下游任务指标的影响。

图 1 绘制了每种模型配置和优化的吞吐量,分别使用局部批次大小为 8 和单个节点上可能的最大批次大小。模型变体在优化中没有数据点表示该模型无法在单个节点上进行训练。

图 2 绘制了每种优化每个 worker 可能的最大批次大小。我们观察到一些现象:

  1. 模型大小扩展:DDP 只能在一个节点上容纳 3.5 亿和 9 亿参数的模型。通过 FSDP,由于节省了内存,我们能够训练比 DDP 大约 3 倍的模型(即 18 亿和 27 亿参数的变体)。将激活检查点 (AC) 与 FSDP 结合使用,可以训练更大的模型,比 DDP 大约 10 倍(即 48 亿和 100 亿参数的变体)。
  2. 吞吐量
    • 对于较小的模型尺寸,在批次大小保持为 8 的情况下,DDP 的吞吐量略高于或等于 FSDP,这可以通过 FSDP 所需的额外通信来解释。当 FSDP 和 AC 结合使用时,吞吐量最低。这是因为 AC 在反向传播过程中重新运行检查点的前向传播,以计算量换取内存节省。然而,在 27 亿参数模型的情况下,FSDP + AC 的吞吐量实际上高于单独使用 FSDP。这是因为即使在批次大小为 8 时,使用 FSDP 的 27 亿参数模型也接近内存限制,从而触发 CUDA malloc 重试,这往往会减慢训练速度。AC 有助于减轻内存压力,并避免重试。
    • 对于 DDP 和 FSDP + AC,每个模型的吞吐量都随着批次大小的增加而增加。对于单独的 FSDP,这适用于较小的变体。然而,对于 18 亿和 27 亿参数的模型,我们观察到在增加批次大小时吞吐量下降。造成这种情况的一个潜在原因(如上所述)是,在内存限制下,PyTorch 的 CUDA 内存管理可能不得不重试 cudaMalloc 调用和/或运行昂贵的碎片整理步骤以查找空闲内存块来处理工作负载的内存需求,这可能导致训练速度变慢。
    • 对于只能使用 FSDP 训练的较大模型(18 亿、27 亿、48 亿),实现最高吞吐量的设置是 FSDP + AC 扩展到最大批次大小。对于 100 亿参数模型,我们观察到较小批次大小和最大批次大小的吞吐量几乎相等。这可能与直觉相反,因为 AC 会增加计算量,并且最大批次大小可能会导致由于在 CUDA 内存限制下运行而产生昂贵的碎片整理操作。然而,对于这些大型模型,批次大小的增加足以掩盖这种开销。

图 1:不同配置的训练吞吐量

  1. 批次大小:单独的 FSDP 能够实现比 DDP 略高的批次大小。对于 3.5 亿参数模型,使用 FSDP + AC 可以实现比 DDP 大约 3 倍的批次大小,对于 9 亿参数模型,可以实现大约 5.5 倍的批次大小。即使对于 100 亿参数模型,最大批次大小约为 20,这已经相当不错了。这实质上允许使用更少的 GPU 实现更大的全局批次大小,这对于对比学习任务特别有用。

图 2:不同配置可能的最大局部批次大小

结论

随着世界迈向多模态基础模型,扩展模型参数和高效训练正成为关注的焦点。PyTorch 生态系统旨在通过为研究社区提供用于训练和扩展多模态模型的不同工具来加速该领域的创新。通过 FLAVA,我们提供了一个扩展多模态理解模型的示例。未来,我们计划增加对其他类型模型(如多模态生成模型)的支持,并展示它们的扩展因素。我们还希望自动化许多这些扩展和内存节省技术(如分片和激活检查点),以减少用户为实现所需规模和最大训练吞吐量所需的实验量。

参考文献