简介
近年来,扩展模型规模已成为一个充满前景的研究领域。在自然语言处理(NLP)领域,语言模型已从数亿参数(BERT)发展到数千亿参数(GPT-3),并在下游任务中表现出显著的性能提升。大型语言模型的扩展定律 (Scaling laws) 也已在业界得到了广泛研究。在视觉领域也可以观察到类似的趋势,社区正在转向基于 Transformer 的模型(如 Vision Transformer、Masked Auto Encoders)。显而易见,文本、图像、视频等单一模态都极大地受益于近期在模型规模上的进步,而各种框架也迅速调整以适应更大的模型。
与此同时,多模态在研究中的重要性日益凸显,图像-文本检索、视觉问答、视觉对话和文生图等任务在现实世界应用中正获得广泛关注。训练大规模多模态模型是自然而然的下一步,我们已经看到该领域涌现出多项成果,例如 OpenAI 的 CLIP、Google 的 Parti 以及 Meta 的 CM3。
在本篇博客中,我们展示了一个案例研究,说明如何利用 PyTorch Distributed 的技术将 FLAVA 扩展到 10B 参数。FLAVA 是一个视觉和语言基础模型,可在 TorchMultimodal 中获取,该模型在单模态和多模态基准测试中均展现出了极具竞争力的性能。我们还在博客中提供了相关的代码指南。运行用于扩展 FLAVA 的示例脚本的说明可点击此处查看。
FLAVA 扩展概述
FLAVA 是一种多模态基础模型,由基于 Transformer 的图像和文本编码器,以及随后的基于 Transformer 的多模态融合模块组成。它在单模态和多模态数据上使用多种损失函数进行预训练。这包括掩码语言、掩码图像和多模态建模损失,这些损失要求模型根据上下文重构原始输入(自监督学习)。它还使用基于对齐的图像-文本对正负样本的图像文本匹配损失,以及 CLIP 风格的对比损失。除了多模态任务(如图像-文本检索)外,FLAVA 在单模态基准测试(NLP 的 GLUE 任务和视觉的图像分类)中也展示了具有竞争力的性能。

原始的 FLAVA 模型拥有约 3.5 亿(350M)参数,图像和文本编码器采用了 ViT-B16 配置(出自 Vision Transformer 论文)。多模态融合 Transformer 跟随在单模态编码器之后,但层数减半。我们探索了将每个编码器的大小增加到更大的 ViT 变体。
扩展的另一个方面是增加批量(batch size)的能力。FLAVA 使用批内负样本的对比损失,这通常受益于较大的批量(如此处所研究)。此外,当在 GPU 可用显存允许的最大批量附近运行时,通常也能实现最高的训练效率或吞吐量(另请参阅实验部分)。
下表展示了我们实验的不同模型配置。我们在实验部分确定了每种配置在内存中能够容纳的最大批量。
| 模型参数量近似值 | 隐藏层大小 (Hidden size) | MLP 大小 | 注意力头数 (Heads) | 单模态层数 | 多模态层数 | 模型大小 (fp32) |
|---|---|---|---|---|---|---|
| 350M (原始) | 768 | 3072 | 12 | 12 | 6 | 1.33GB |
| 900M | 1024 | 4096 | 16 | 24 | 12 | 3.48GB |
| 1.8B | 1280 | 5120 | 16 | 32 | 16 | 6.66GB |
| 2.7B | 1408 | 6144 | 16 | 40 | 20 | 10.3GB |
| 4.8B | 1664 | 8192 | 16 | 48 | 24 | 18.1GB |
| 10B | 2048 | 10240 | 16 | 64 | 40 | 38GB |
优化概述
PyTorch 提供了多种原生技术来高效扩展模型。在接下来的章节中,我们将介绍其中一些技术,并展示如何将它们应用于将 FLAVA 模型扩展至 100 亿(10B)参数。
分布式数据并行 (Distributed Data Parallel)
分布式训练的一个常见起点是数据并行。数据并行会在每个工作节点(GPU)上复制模型,并将数据集划分到各个工作节点上。不同的工作节点并行处理不同的数据分区,并在更新模型权重之前同步它们的梯度(通过 all-reduce)。下图展示了数据并行处理单个示例的流程(前向传播、反向传播和权重更新步骤)。

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/
PyTorch 提供了一个原生 API DistributedDataParallel (DDP) 来实现数据并行,它可以作为模块包装器使用,如下所示。请参阅 PyTorch Distributed 文档了解更多详细信息。
from torchmultimodal.models.flava.model import flava_model_for_pretraining
import torch
import torch.distributed as dist
model = flava_model_for_pretraining().cuda()
# Initialize PyTorch Distributed process groups
# Please see https://pytorch.ac.cn/tutorials/intermediate/dist_tuto.html for details
dist.init_process_group(backend=”nccl”)
# Wrap model in DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
完全分片数据并行 (Fully Sharded Data Parallel)
训练应用程序的 GPU 内存使用量大致可以分解为:模型输入、中间激活值(梯度计算所需)、模型参数、梯度和优化器状态。扩展模型通常会增加这些元素中的每一个。使用 DDP 扩展模型时,如果单个 GPU 的内存不足,由于它在所有工作节点上复制参数、梯度和优化器状态,最终可能会导致内存溢出(OOM)问题。
为了减少这种复制并节省 GPU 内存,我们可以将模型参数、梯度和优化器状态分片到所有工作节点上,每个工作节点仅管理一个分片。这种技术由微软开发的 ZeRO-3 方法推广。PyTorch 原生实现了一种名为 FullyShardedDataParallel (FSDP) 的 API,该 API 于 PyTorch 1.12 中作为测试功能发布。在模块的前向和反向传播过程中,FSDP 会根据计算需要取消模型参数的分片(使用 all-gather),并在计算后重新分片。它使用 reduce-scatter 集合通信来同步梯度,以确保分片梯度在全局范围内取平均值。FSDP 包装模型的具体前后向传播流程如下所述。

来源:https://engineering.fb.com/2021/07/15/open-source/fsdp/
要使用 FSDP,需要使用该 API 包装模型的子模块,以控制何时对特定子模块进行分片或取消分片。FSDP 提供了一个自动包装 API(请参见 auto_wrap_policy 参数),既可以直接使用,也提供了多种 包装策略,并支持编写您自己的策略。
以下示例演示了如何用 FSDP 包装 FLAVA 模型。我们将自动包装策略指定为 transformer_auto_wrap_policy。这将把单个 Transformer 层(TransformerEncoderLayer)、图像 Transformer(ImageTransformer)、文本编码器(BERTTextEncoder)和多模态编码器(FLAVATransformerWithoutEmbeddings)包装为独立的 FSDP 单元。这采用了递归包装方法以实现高效的内存管理。例如,在单个 Transformer 层的前向或反向传播完成后,其参数会被丢弃,从而释放内存,减少峰值内存使用。
FSDP 还提供了许多可配置的选项来调整应用程序的性能。例如,在我们的用例中,我们说明了新的 limit_all_gathers 标志的使用,它防止过早地 all-gather 模型参数,从而减轻应用程序的内存压力。我们鼓励用户尝试使用此标志,这可能会提高内存占用较高的应用程序的性能。
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.models.flava.text_encoder import BertTextEncoder
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer
model = flava_model_for_pretraining().cuda()
dist.init_process_group(backend=”nccl”)
model = FSDP(
model,
device_id=torch.cuda.current_device(),
auto_wrap_policy=partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
TransformerEncoderLayer,
ImageTransformer,
BERTTextEncoder,
FLAVATransformerWithoutEmbeddings
},
),
limit_all_gathers=True,
)
激活检查点 (Activation Checkpointing)
如上所述,中间激活值、模型参数、梯度和优化器状态构成了整体 GPU 内存使用量。FSDP 可以减少后三者导致的内存消耗,但不会减少激活值所占用的内存。激活值占用的内存随着批量或隐藏层数量的增加而增加。激活检查点是一种通过在反向传播过程中重新计算激活值(而不是将它们保存在内存中)来减少内存占用的技术。例如,我们观察到在 2.7B 参数模型上应用激活检查点后,前向传播后的峰值活跃内存减少了约 4 倍。
PyTorch 提供了基于包装器的激活检查点 API。具体而言,checkpoint_wrapper 允许用户使用检查点包装单个模块,而 apply_activation_checkpointing 允许用户指定一种策略,在整体模块内包装特定模块。这两个 API 都可以应用于大多数模型,因为它们不需要修改模型定义代码。但是,如果需要对检查点片段进行更细粒度的控制(例如检查点模块内的特定函数),可以利用函数式 torch.utils.checkpoint API,但这需要修改模型代码。下面展示了将激活检查点包装器应用于单个 FLAVA Transformer 层(表示为 TransformerEncoderLayer)。有关激活检查点的详细描述,请参阅 PyTorch 文档中的描述。
from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer
model = flava_model_for_pretraining()
checkpoint_tformer_layers_policy = lambda submodule: isinstance(submodule, TransformerEncoderLayer)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=checkpoint_wrapper,
check_fn=checkpoint_tformer_layers_policy,
)
综合使用上述方法:用激活检查点包装 FLAVA Transformer 层,并用 FSDP 包装整体模型,我们能够将 FLAVA 扩展到 10B 参数。
实验
我们针对上一节中不同优化技术对系统性能的影响进行了实证研究。对于所有实验,我们使用包含 8 个 A100 40GB GPU 的单个节点,并运行 1000 次迭代的预训练。所有运行均使用了 PyTorch 的 自动混合精度 (AMP) 以及 bfloat16 数据类型。同时启用了 TensorFloat32 格式以提高 A100 上的矩阵乘法性能。我们将吞吐量定义为每秒处理的平均项目数(文本或图像)(我们在计算吞吐量时忽略了前 100 次迭代,以考虑热身因素)。我们将训练至收敛及其对下游任务指标的影响留作未来研究的领域。
图 1 绘制了每种模型配置和优化技术的吞吐量,分别以 8 的局部批量和节点上可能的最大批量进行绘制。如果某个模型变体在某种优化下缺失数据点,则表示该模型无法在单节点上进行训练。
图 2 绘制了每种优化技术下每个工作节点可能的最大批量。我们观察到几点:
- 扩展模型大小:DDP 只能在节点上容纳 350M 和 900M 的模型。使用 FSDP,由于内存节省,我们能够训练比 DDP 大约 3 倍的模型(即 1.8B 和 2.7B 变体)。将激活检查点 (AC) 与 FSDP 结合,能够训练更大的模型,规模可达 DDP 的约 10 倍(即 4.8B 和 10B 变体)。
- 吞吐量
- 对于较小的模型规模,在 8 的固定批量下,DDP 的吞吐量略高于或等于 FSDP,这可以通过 FSDP 所需的额外通信来解释。对于 FSDP 和 AC 的组合,吞吐量最低。这是因为 AC 在反向传播期间重新运行了检查点前向传播,以额外的计算代价换取了内存节省。然而,在 2.7B 模型的情况下,FSDP + AC 的吞吐量实际上比单纯使用 FSDP 更高。这是因为 2.7B 模型在 FSDP 下即使在批量为 8 时也接近内存极限,触发了 CUDA malloc 重试,这往往会减慢训练速度。AC 有助于减轻内存压力,从而避免了重试。
- 对于 DDP 和 FSDP + AC,吞吐量随着每个模型批量的增加而增加。对于单纯的 FSDP,这对于较小的变体是成立的。然而,对于 1.8B 和 2.7B 参数的模型,我们观察到在增加批量时吞吐量出现了下降。如上所述,一个潜在的原因是:在内存极限时,PyTorch 的 CUDA 内存管理可能必须重试 cudaMalloc 调用和/或运行昂贵的碎片整理步骤,以查找空闲内存块来处理工作负载的内存需求,这可能导致训练减速。
- 对于只能使用 FSDP 训练的大型模型(1.8B、2.7B、4.8B),实现最高吞吐量的设置是 FSDP + AC 扩展至最大批量。对于 10B,我们观察到较小批量和最大批量下的吞吐量几乎相等。这可能违背直觉,因为 AC 导致计算增加,并且最大化批量可能会因在 CUDA 内存极限下运行而导致昂贵的碎片整理操作。然而,对于这些大型模型,批量增加的幅度足以掩盖这种开销。

图 1:不同配置的训练吞吐量
- 批量大小:与 DDP 相比,单独使用 FSDP 可实现略高的批量。对于 350M 参数模型,使用 FSDP + AC 可实现约为 DDP 3 倍的批量;对于 900M 参数模型,可实现约 5.5 倍。即使对于 10B 模型,最大批量约为 20,这已经相当不错了。这实质上可以使用更少的 GPU 实现更大的全局批量,这对对比学习任务特别有用。

图 2:不同配置下可能的最大局部批量
结论
随着全球向多模态基础模型迈进,扩展模型参数和高效训练正成为关注的焦点。PyTorch 生态系统旨在通过为研究社区提供不同的工具(无论是用于训练还是扩展多模态模型)来加速该领域的创新。通过 FLAVA,我们提供了一个扩展多模态理解模型的示例。未来,我们计划增加对其他模型类型的支持,例如多模态生成模型,并展示它们的扩展因子。我们还希望实现许多此类扩展和内存节省技术(如分片和激活检查点)的自动化,以减少用户为达到预期规模和最大训练吞吐量所需的实验量。