作者: Vaibhav Aggarwal, Mannat Singh, Anjali Sridhar, Yanghao Li, Shoubhik Debnath, Ronghang Hu, Will Feng, Xinlei Chen, Tingting Markstrum, Diana Liskovich, Anupam Bhatnagar, Chay Ryali, Haoqi Fan, Tete Xiao, Min Xu, Rahul Iyer, Christoph Feichtenhofer, Ross Girshick, Piotr Dollar, Aaron Adcock, Wan-Yen Lo, CK Luk

TL;DR: 我们演示了如何将 PyTorch 与 FairScale 的 FullyShardedDataParallel (FSDP) API 结合使用来编写大型视觉 Transformer 模型。我们讨论了在 GPU 集群上扩展和优化这些模型的技术。本次平台扩展工作的目标是支持大规模研究。本博客不讨论模型准确性、新的模型架构或新的训练秘籍。

1. 引言

最新的视觉研究 [1, 2] 表明模型扩展是一个有前景的研究方向。在本项目中,我们的目标是使我们的平台能够训练大型视觉 Transformer (ViT) [3] 模型。我们介绍了在 FAIR 视觉平台中将最大可训练 ViT 从 10 亿参数扩展到 1200 亿参数的工作。我们使用 PyTorch 编写了 ViT,并利用其对 GPU 集群上大规模分布式训练的支持。

在本博客的其余部分,我们将首先讨论主要挑战,即可扩展性优化数值稳定性。然后我们将讨论如何通过数据和模型并行性自动混合精度核函数融合bfloat16 等技术来解决这些问题。最后,我们将展示我们的结果并总结。

2. 主要挑战

2.1 可扩展性

关键的可扩展性挑战是如何有效地将模型的操作和状态分片到多个 GPU 上。一个 1000 亿参数的模型仅参数就需要约 200GB 内存(假设采用 fp16 表示)。因此,无法将模型完全放入单个 GPU(A100 最大只有 80GB 内存)。我们需要某种方式来有效地将模型的数据(输入、参数、激活和优化器状态)分片到多个 GPU 上。

这个问题的另一方面是,在扩展的同时不显著改变训练秘籍。例如,某些表示学习秘籍使用高达 4096 的全局批次大小,超过此值会开始出现精度下降。如果没有某种形式的张量或流水线并行,我们就无法扩展到超过 4096 个 GPU。

2.2 优化

关键的优化挑战是在扩展模型参数数量和浮点运算次数的同时,保持高 GPU 利用率。当我们把模型扩展到 teraflops 及更高时,我们的软件堆栈开始出现主要瓶颈,这会超线性地增加训练时间并降低加速器利用率。我们仅运行一个实验就需要数百或数千个 GPU。提高加速器利用率可以显著降低成本并改善集群利用率。这使我们能够资助更多项目并并行运行更多实验。

2.3 数值稳定性

关键的稳定性挑战是在大规模训练时避免数值不稳定和发散。我们在实验中凭经验观察到,当我们扩大模型大小、数据量、批次大小、学习率等时,训练不稳定性变得严重且难以处理。视觉 Transformer 即使在较低的参数阈值下也面临训练不稳定性。例如,我们发现在不使用强数据增强的情况下,即使是 ViT-H(仅有 6.3 亿参数)在混合精度模式下也难以训练。我们需要研究模型特性和训练秘籍,以确保模型能够稳定训练并收敛。

3. 我们的解决方案

图 1 描绘了我们针对每个挑战的解决方案。

3.1 利用数据并行和模型并行解决扩展挑战

我们应用了各种形式的数据并行和模型并行,以实现在 GPU 内存中适应非常大的模型。

我们使用基于 PyTorch 的 FairScale 的 FullyShardedDataParallel (FSDP) API [4],将参数、梯度和优化器状态分片到多个 GPU 上,从而减少每个 GPU 的内存占用。此过程包含以下三个步骤:

  • 步骤 1:我们将整个模型封装在一个 FSDP 实例中。这会在前向传播结束时分片模型参数,并在前向传播开始时收集参数。这使我们能够将参数量从 15 亿扩展到 45 亿,实现了约 3 倍的扩展。

  • 步骤 2:我们尝试将单个模型层封装在独立的 FSDP 实例中。这种嵌套封装通过分片和收集单个模型层而非整个模型的参数,进一步减少了内存占用。在这种模式下,峰值内存消耗取决于 GPU 内存中单独封装的 Transformer 块,而不是整个模型。

  • 步骤 3:我们使用激活检查点来减少激活的内存消耗。它在前向传播过程中保存输入张量并丢弃中间激活张量。这些张量在后向传播过程中重新计算。

此外,我们尝试了模型并行技术,例如流水线并行 [5],这使我们能够在不增加批次大小的情况下扩展到更多 GPU。

3.2 利用高级 AMP 和核函数融合解决优化挑战

高级 AMP

自动混合精度 (AMP) [6] 训练是指使用比 FP32 或默认精度低的位精度来训练模型,但仍能保持准确性。我们尝试了以下三个级别的 AMP:

  • AMP O1:这是指混合精度训练,其中权重采用 FP32,部分运算采用 FP16。在 AMP O1 中,可能影响精度的运算仍保留在 FP32 中,不会自动转换为 FP16。

  • AMP O2:这是指混合精度训练,但采用 FP16 的权重和运算比 O1 更多。权重不会隐式保留在 FP32 中,而是转换为 FP16。主权重的一个副本以 FP32 精度维护,供优化器使用。如果希望归一化层权重保持 FP32,则需要显式使用层封装来确保。

  • 全 FP16:这是指完全采用 FP16 进行训练,其中权重和运算都采用 FP16。由于收敛问题,完全启用 FP16 训练具有挑战性。

我们发现,使用 AMP O2 并将 LayerNorm 封装在 FP32 中,可以在不牺牲准确性的情况下获得最佳性能。

核函数融合

  • 为了减少 GPU 核函数启动开销并增加 GPU 工作粒度,我们使用 xformers 库 [7] 尝试了核函数融合,包括融合 dropout 和融合 layer-norm。

3.3 通过研究运算的数值稳定性和训练秘籍解决稳定性挑战

总体采用 BFloat16,但 LayerNorm 采用 FP32

bfloat16 (BF16) [8] 浮点格式提供了与 FP32 相同的动态范围,同时内存占用与 FP16 相同。我们发现,使用与 FP32 相同的超参数集,可以在 BF16 格式下训练模型,无需特别的参数调整。尽管如此,我们发现为了使训练收敛,需要将 LayerNorm 保持在 FP32 模式下。

3.4 最终训练秘籍

最终训练秘籍总结。

  1. 将外部模型封装在 FSDP 实例中。在前向传播后启用参数分片。
  2. 将单个 ViT 块封装,启用激活检查点、嵌套 FSDP 封装和参数展平。
  3. 启用混合精度模式 (AMP O2),采用 bfloat16 表示。保持优化器状态在 FP32 精度,以增强数值稳定性。
  4. 将 LayerNorm 等归一化层封装在 FP32 中,以获得更好的数值稳定性。
  5. 通过保持矩阵维度为 8 的倍数来最大化 Nvidia TensorCore 的利用率。更多详情请查阅 Nvidia Tensor Core 性能指南

4. 结果

在本节中,我们将展示 ViT 在三种任务上的扩展结果:(1)图像分类,(2)目标检测,(3)视频理解。我们的关键结果是,在应用了讨论的扩展和优化技术后,我们能够在这些视觉任务上训练大型 ViT 骨干网络。这使得视觉研究能够以更大的规模进行。 我们将模型训练至收敛,以验证即使应用了所有优化,我们仍能保持当前的基线。图 2、3、4 的共同趋势是,我们能够在 128 个 A100 GPU 上训练参数量高达 250 亿的模型,每个 epoch 的时间少于 4 小时。参数量为 600 亿和 1200 亿的模型训练速度相对较慢。

图 2 显示了图像分类的扩展结果。它绘制了在使用 128 个 A100-80GB GPU 训练不同模型尺寸的 ViT 在 ImageNet 上的每个 epoch 时间。

图 2:图像分类扩展结果。

图 3 显示了目标检测的扩展结果。它绘制了在使用 128 个 A100-80GB GPU 训练不同 ViT 骨干网络的 ViTDet [9] 在 COCO 上的每个 epoch 时间。

图 3:目标检测扩展结果。

图 4 显示了视频理解的扩展结果。它绘制了在使用 128 个 V100 (32 GB) GPU(采用 FP32 精度)训练 MViTv2 [10] 模型在 Kinetics 400 [11] 上的每个 epoch 时间。

图 4:视频理解扩展结果。

图 5 显示了使用图 2 中的 ViT-H 模型在 8 个 A100-40GB GPU 上的优化结果。使用了三个版本:(1)基线使用 PyTorch 的 DDP [12] 结合 AMP O1,(2)FSDP + AMP-O2 + 其他优化,以及(3)FSDP + FP16 + 其他优化。这些优化总共将训练速度提高了最多 2.2 倍。

图 5:各种优化带来的训练加速。

5. 总结

我们演示了如何将 PyTorch 与 FairScale 的 FullyShardedDataParallel (FSDP) API 结合使用来编写大型视觉 Transformer 模型。我们讨论了在 GPU 集群上扩展和优化这些模型的技术。我们希望本文能激励其他人使用 PyTorch 及其生态系统开发大规模机器学习模型。

参考文献

[1] Masked Autoencoders:可扩展的视觉学习器

[2] 重访弱监督预训练的视觉感知模型

[3] 图像即 16x16 词:用于大规模图像识别的 Transformer

[4] fairscale.nn.FullyShardedDataParallel

[5] PyTorch 中的流水线并行

[6] PyTorch 中的自动混合精度 (AMP)

[7] xformers

[8] bfloat16 数值格式

[9] 探索用于目标检测的纯视觉 Transformer 骨干网络

[10] MViTv2:改进的多尺度视觉 Transformer,用于分类和检测

[11] https://www.deepmind.com/open-source/kinetics

[12] 分布式数据并行 (DDP) 入门