作者:Syed Ahmed, Christian Sarofeen, Mike Ruberry, Eddie Yan, Natalia Gimelshein, Michael Carilli, Szymon Migacz, Piotr Bialecki, Paulius Micikevicius, Dusan Stosic, Dong Yang, and Naoya Maruyama

现代神经网络的高效训练通常依赖于使用低精度数据类型。在 A100 GPU 上,float16 矩阵乘法和卷积的峰值性能比 float32 峰值性能快 16 倍。由于 float16 和 bfloat16 数据类型的大小仅为 float32 的一半,它们可以使带宽受限的核函数性能加倍,并减少训练网络所需的内存,从而支持更大的模型、更大的批次或更大的输入。使用像 torch.amp 这样的模块(简称“自动混合精度”)可以轻松获得低精度数据类型的速度和内存使用优势,同时保持收敛行为。

更快并使用更少内存总是更有优势——深度学习从业者可以测试更多的模型架构和超参数,并且可以训练更大、更强大的模型。不使用混合精度,训练像 Narayanan et al.Brown et al. 所述的超大型模型(即使有专家手工优化,也需要数千个 GPU 耗时数月训练)是不可行的。

我们之前讨论过混合精度技术(此处此处此处),本博文是对这些技术的总结,如果您是混合精度新手,也是一个入门介绍。

混合精度训练实践

混合精度训练技术——将低精度 float16 或 bfloat16 数据类型与 float32 数据类型并行使用——是广泛适用且有效的。图 1 展示了成功使用混合精度训练的部分模型示例,图 2 和图 3 展示了使用 torch.amp 的加速示例。

图 1:使用 float16 成功训练的深度学习工作负载抽样(来源)。

图 2:使用 torch.amp 在 NVIDIA 8xV100 上进行混合精度训练的性能对比使用 8xV100 GPU 进行 float32 训练。柱状图表示 torch.amp 相较于 float32 的加速因子。(越高越好。)(来源)。

图 3:使用 torch.amp 在 NVIDIA 8xA100 上的混合精度训练性能对比 8xV100 GPU。柱状图表示 A100 相较于 V100 的加速因子。(越高越好。)(来源)。

有关更多混合精度工作负载示例,请参阅 NVIDIA 深度学习示例仓库

3D 医学图像分析注视估计视频合成条件 GAN卷积 LSTM 中可以看到类似的性能图表。Huang et al. 的研究表明,在 V100 GPU 上,混合精度训练比 float32 快 1.5 到 5.5 倍,在 A100 GPU 上额外快 1.3 到 2.5 倍,适用于各种网络。在超大型网络上,对混合精度的需求更加明显。Narayanan et al. 报告称,在 1024 个 A100 GPU 上训练 GPT-3 175B(批次大小为 1536)需要 34 天,但估计使用 float32 将需要一年多的时间!

使用 torch.amp 开始混合精度训练

torch.amp 在 PyTorch 1.6 中引入,可以轻松利用 float16 或 bfloat16 数据类型进行混合精度训练。有关更多详细信息,请参阅这篇博文、此教程文档。图 4 展示了将 AMP 与梯度缩放应用于网络的示例。

import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

for data, label in data_iter:
   optimizer.zero_grad()
   # Casts operations to mixed precision
   with torch.amp.autocast(device_type=“cuda”, dtype=torch.float16):
      loss = model(data)

   # Scales the loss, and calls backward()
   # to create scaled gradients
   scaler.scale(loss).backward()

   # Unscales gradients and calls
   # or skips optimizer.step()
   scaler.step(optimizer)

   # Updates the scale for next iteration
   scaler.update()

图 4:AMP 代码示例

选择正确的方法

开箱即用的 float16 或 bfloat16 混合精度训练对于加速许多深度学习模型的收敛是有效的,但有些模型可能需要更精细的数值精度管理。以下是一些选项:

  • 完整 float32 精度。PyTorch 中默认以 float32 精度创建浮点张量和模块,但这只是一个历史遗留,不代表大多数现代深度学习网络的训练方式。网络很少需要这么高的数值精度。
  • 启用 TensorFloat32 (TF32) 模式。在 Ampere 及更高版本的 CUDA 设备上,矩阵乘法和卷积可以使用 TensorFloat32 (TF32) 模式,计算更快但精度略低。有关更多详细信息,请参阅使用 NVIDIA TF32 Tensor Core 加速 AI 训练这篇博文。默认情况下,PyTorch 为卷积启用 TF32 模式,但不为矩阵乘法启用,除非网络需要完整 float32 精度,否则我们建议也为矩阵乘法启用此设置(请参阅此处的文档了解如何操作)。它可以在数值精度损失通常可忽略不计的情况下显著加速计算。
  • 使用 torch.amp 结合 bfloat16 或 float16。这两种低精度浮点数据类型通常速度相当,但有些网络可能只能使用其中一种收敛。如果网络需要更高精度,可能需要使用 float16;如果网络需要更大的动态范围,可能需要使用 bfloat16,其动态范围等于 float32。例如,如果观察到溢出,我们建议尝试 bfloat16。

还有比此处介绍的更高级的选项,例如仅对模型的某些部分使用 torch.amp 的自动类型转换,或直接管理混合精度。这些主题在很大程度上超出了本博文的范围,但请参阅下方的“最佳实践”部分。

最佳实践

我们强烈建议在训练网络时,只要可能就使用 torch.amp 或 TF32 模式(在 Ampere 及更高版本的 CUDA 设备上)进行混合精度训练。但是,如果其中一种方法不起作用,我们建议以下方法:

  • 高性能计算 (HPC) 应用、回归任务和生成网络可能需要完整 float32 IEEE 精度才能按预期收敛。
  • 尝试选择性应用 torch.amp。特别地,我们建议首先在执行 torch.linalg 模块操作的区域或进行预处理或后处理时禁用它。这些操作通常特别敏感。请注意,TF32 模式是全局开关,不能在网络的特定区域选择性使用。请先启用 TF32 检查网络的运算符是否对该模式敏感,否则禁用它。
  • 如果您在使用 torch.amp 时遇到类型不匹配,我们不建议一开始就插入手动类型转换。此错误表明网络存在异常,通常值得先调查。
  • 通过实验确定您的网络是否对格式的范围和/或精度敏感。例如,在 float16 中微调 bfloat16 预训练模型很容易在 float16 中遇到范围问题,因为在 bfloat16 中训练可能产生较大范围,因此如果模型是在 bfloat16 中训练的,用户应坚持使用 bfloat16 微调。
  • 混合精度训练的性能提升取决于多种因素(例如计算密集型问题与内存密集型问题),用户应使用调优指南来消除训练脚本中的其他瓶颈。尽管 BF16 和 FP16 具有类似的理论性能优势,但在实践中速度可能不同。建议尝试提到的格式,并在保持所需数值行为的同时使用速度最佳的格式。

更多详细信息,请参阅AMP 教程使用 Tensor Core 训练神经网络,并查看 PyTorch 开发者讨论上的博文“浮点精度更深入的细节”。

结论

混合精度训练是在现代硬件上训练深度学习模型的基本工具,并且随着新硬件上低精度操作与 float32 之间的性能差距不断扩大(体现在图 5 中),未来会变得更加重要。

图 5:Volta 和 Ampere GPU 上 float16 (FP16) 相较于 float32 矩阵乘法的相对峰值吞吐量。在 Ampere 上,还展示了 TensorFloat32 (TF32) 模式和 bfloat16 矩阵乘法的相对峰值吞吐量。随着新硬件的发布,float16 和 bfloat16 等低精度数据类型相较于 float32 矩阵乘法的相对峰值吞吐量预计会增长。

PyTorch 的 torch.amp 模块使得开始使用混合精度变得容易,我们强烈建议使用它来更快训练并减少内存使用。torch.amp 支持 float16 和 bfloat16 混合精度。

仍有一些网络使用混合精度训练起来比较棘手,对于这些网络,我们建议在 Ampere 及更高版本的 CUDA 硬件上尝试 TF32 加速的矩阵乘法。网络很少对精度如此敏感,以至于每个操作都需要完整 float32 精度。

如果您对 PyTorch 中的 torch.amp 或混合精度支持有疑问或建议,请通过在PyTorch 论坛的混合精度类别发帖或在PyTorch GitHub 页面上提交问题告知我们。