TL;DR:我们展示了如何使用 PyTorch 配合 FairScale 的 FullyShardedDataParallel (FSDP) API 来编写大型视觉 Transformer 模型。我们讨论了在 GPU 集群上扩展和优化这些模型的技术。此次平台扩展工作的目标是支持大规模研究。本博客不讨论模型准确性、新模型架构或新的训练配方。
1. 引言
最新的视觉研究 [1, 2] 表明,模型缩放是一个很有前景的研究方向。在这个项目中,我们的目标是使我们的平台能够训练大规模的视觉 Transformer (ViT) [3] 模型。我们展示了在 FAIR 视觉平台中将可训练的最大 ViT 从 10 亿参数扩展到 1200 亿参数的工作。我们使用 PyTorch 编写了 ViT,并利用其对 GPU 集群上大规模分布式训练的支持。
在本文的其余部分,我们将首先讨论主要挑战,即可扩展性、优化和数值稳定性。然后,我们将讨论如何通过数据和模型并行、自动混合精度、算子融合和bfloat16等技术来解决这些挑战。最后,我们将展示我们的结果并得出结论。
2. 主要挑战
2.1 可扩展性
可扩展性的核心挑战在于如何高效地在多个 GPU 之间切分模型的算子和状态。假设使用 fp16 表示,一个 1000 亿参数的模型仅参数本身就需要约 200GB 的内存。因此,不可能将模型放入单个 GPU 中(A100 最多只有 80GB 内存)。因此,我们需要某种方法在多个 GPU 之间高效地切分模型数据(输入、参数、激活值和优化器状态)。
该问题的另一个方面是在不显著改变训练配方的情况下进行扩展。例如,某些表征学习配方使用的全局批大小最高可达 4096,超过此数值后,我们开始观察到准确性下降。如果不使用某种形式的张量并行或流水线并行,我们无法扩展到超过 4096 个 GPU。
2.2 优化
优化的核心挑战是在扩展模型参数和浮点运算量 (flops) 时保持高 GPU 利用率。当我们将模型扩展到每秒万亿次浮点运算 (teraflops) 及以上时,我们的软件栈开始遇到主要瓶颈,这会以超线性方式增加训练时间并降低加速器利用率。我们运行单个实验就需要数百或数千个 GPU。加速器利用率的提升可以显著降低成本并改善机群利用率。这使我们能够资助更多的项目并并行运行更多的实验。
2.3 数值稳定性
稳定性的核心挑战是避免在大规模训练时出现数值不稳定和发散。我们在实验中凭经验观察到,当我们扩大模型规模、数据、批大小、学习率等时,训练不稳定性会变得严重且难以处理。视觉 Transformer 即使在较低的参数阈值下也面临训练不稳定性。例如,我们发现即使在没有使用强数据增强的情况下,以混合精度模式训练 ViT-H(仅 6.3 亿参数)也具有挑战性。我们需要研究模型属性和训练配方,以确保模型能够稳定训练并收敛。
3. 我们的解决方案
图 1 描绘了我们针对每一项挑战的解决方案。

3.1 通过数据并行和模型并行解决扩展挑战
我们应用了各种形式的数据和模型并行,以便将超大型模型放入 GPU 内存中。
我们使用了基于 PyTorch 的 FairScale 的 FullyShardedDataParallel (FSDP) API [4],在多个 GPU 之间切分参数、梯度和优化器状态,从而减少了每个 GPU 的内存占用。该过程包含以下三个步骤:
- 第 1 步:我们将整个模型包装在一个 FSDP 实例中。这会在前向传播结束时切分模型参数,并在前向传播开始时收集参数。这使我们能够将参数量从 15 亿扩展到 45 亿,实现了约 3 倍的扩展。
- 第 2 步:我们尝试将单个模型层包装在单独的 FSDP 实例中。这种嵌套包装通过切分和收集单个模型层的参数而不是整个模型,进一步减少了内存占用。在这种模式下,峰值内存由 GPU 内存中单个包装好的 Transformer 块决定,而不是整个模型。
- 第 3 步:我们使用 激活检查点 (activation-checkpoint) 来减少激活值的内存消耗。它在前向传播期间保存输入张量并丢弃中间激活张量,这些张量在反向传播期间重新计算。
此外,我们还尝试了模型并行技术,如流水线并行 [5],这使我们能够在不增加批大小的情况下扩展到更多的 GPU。
3.2 通过高级 AMP 和算子融合解决优化挑战
高级 AMP
自动混合精度 (AMP) [6] 训练是指使用比 FP32 或默认精度更低的位宽进行模型训练,同时保持准确性。我们尝试了以下三种级别的 AMP:
- AMP O1:这指的是混合精度训练,其中权重以 FP32 存储,部分运算以 FP16 进行。使用 AMP O1 时,可能影响准确性的运算保留在 FP32 中,不会自动转换为 FP16。
- AMP O2:这指的是混合精度训练,但相比 O1,更多的权重和运算以 FP16 进行。权重不会隐式保留在 FP32 中,而是被转换为 FP16。优化器使用的 FP32 主权重副本会被保留。如果我们希望归一化层权重保持在 FP32,则需要显式使用层包装来确保这一点。
- 全 FP16:这指的是在全 FP16 中进行训练,其中权重和运算均以 FP16 进行。由于收敛问题,全 FP16 训练很难实现。
我们发现,使用 FP32 层归一化 (LayerNorm) 包装的 AMP O2 在不牺牲准确性的情况下提供了最佳性能。
算子融合
- 为了减少 GPU 算子启动开销并提高 GPU 工作粒度,我们使用 xformers 库 [7] 尝试了算子融合,包括融合 Dropout 和融合 LayerNorm。
3.3 通过研究算子数值稳定性和训练配方解决稳定性挑战
BFloat16(常规)配合 FP32 LayerNorm
bfloat16 (BF16) [8] 浮点格式提供与 FP32 相同的动态范围,且内存占用与 FP16 相同。我们发现可以使用与 FP32 相同的一组超参数在 BF16 格式中训练模型,无需进行特殊的参数调整。尽管如此,我们发现需要将 LayerNorm 保持在 FP32 模式下,训练才能收敛。
3.4 最终训练配方
最终训练配方的总结:
- 将外部模型包装在 FSDP 实例中。在前向传播后启用参数切分。
- 使用激活检查点、嵌套 FSDP 包装和参数平铺来包装单个 ViT 块。
- 启用具有 bfloat16 表示的混合精度模式 (AMP O2)。将优化器状态保持在 FP32 精度以增强数值稳定性。
- 将 LayerNorm 等归一化层包装在 FP32 中,以实现更好的数值稳定性。
- 通过保持矩阵维度为 8 的倍数,最大化 Nvidia TensorCore 的利用率。更多详情请查看 Nvidia Tensor Core 性能指南。
4. 结果
在本节中,我们展示了 ViT 在三类任务上的扩展结果:(1) 图像分类,(2) 目标检测,(3) 视频理解。我们的核心成果是,在应用了所讨论的扩展和优化技术后,我们能够训练跨这些视觉任务的大规模 ViT 主干网络。这使得更大规模的视觉研究成为可能。我们训练模型至收敛,以验证即使在所有优化措施下,我们仍能保持当前的基准性能。图 2、3、4 的共同趋势是,我们能够在 128 个 A100 GPU 上以不到 4 小时的 epoch 时间训练高达 250 亿参数的模型。600 亿和 1200 亿参数模型的训练速度相对较慢。
图 2 显示了图像分类的扩展结果。它绘制了使用 128 个 A100-80GB GPU 训练不同模型大小的 ViT 在 ImageNet 上的 epoch 时间。

图 2:图像分类扩展结果。
图 3 显示了目标检测的扩展结果。它绘制了使用 128 个 A100-80GB GPU 在 COCO 上训练具有不同 ViT 主干的 ViTDet [9] 的 epoch 时间。

图 3:目标检测扩展结果。
图 4 显示了视频理解的扩展结果。它绘制了使用 128 个 V100 (32 GB) GPU 在 FP32 下训练 MViTv2 [10] 模型在 Kinetics 400 [11] 上的 epoch 时间。

图 4:视频理解扩展结果。
图 5 显示了图 2 中 ViT-H 模型在 8 个 A100-40GB GPU 上的优化结果。使用了三个版本:(1) 使用 PyTorch DDP [12] 配合 AMP O1 的基准,(2) FSDP + AMP-O2 + 其他优化,以及 (3) FSDP + FP16 + 其他优化。这些优化总共使训练速度提升了高达 2.2 倍。

图 5:各种优化带来的训练加速。
5. 结束语
我们展示了使用 PyTorch 配合 FairScale 的 FullyShardedDataParallel (FSDP) API 来编写大型视觉 Transformer 模型的方法。我们讨论了在 GPU 集群上扩展和优化这些模型的技术。我们希望本文能够激励更多开发者使用 PyTorch 及其生态系统来开发大规模机器学习模型。
参考资料
[1] Masked Autoencoders Are Scalable Vision Learners(掩码自动编码器是可扩展的视觉学习器)
[2] Revisiting Weakly Supervised Pre-Training of Visual Perception Models(重温视觉感知模型的弱监督预训练)
[4] fairscale.nn.FullyShardedDataParallel
[5] PyTorch 中的流水线并行
[7] xformers
[8] bfloat16 数值格式
[9] Exploring Plain Vision Transformer Backbones for Object Detection(探索用于目标检测的朴素视觉 Transformer 主干)
[11] https://www.deepmind.com/open-source/kinetics
[12] 分布式数据并行 (DDP) 入门