作者:Syed Ahmed、Christian Sarofeen、Mike Ruberry、Eddie Yan、Natalia Gimelshein、Michael Carilli、Szymon Migacz、Piotr Bialecki、Paulius Micikevicius、Dusan Stosic、Dong Yang 和 Naoya Maruyama

现代神经网络的高效训练通常依赖于使用较低精度的数据类型。在 A100 GPU 上,float16 矩阵乘法和卷积的峰值性能比 float32 的峰值性能快 16 倍。由于 float16 和 bfloat16 数据类型的大小仅为 float32 的一半,它们可以将带宽受限内核的性能提高一倍,并减少训练网络所需的内存,从而允许更大的模型、更大的批次或更大的输入。使用像 torch.amp(“自动混合精度”的缩写)这样的模块可以轻松获得较低精度数据类型的速度和内存使用优势,同时保持收敛行为。

更快地运行和使用更少的内存始终是有利的——深度学习从业者可以测试更多的模型架构和超参数,并且可以训练更大、更强大的模型。训练非常大的模型,例如 Narayanan 等人Brown 等人 中描述的模型(即使有专家手写的优化,也需要数千个 GPU 训练数月),如果不使用混合精度,则是不可行的。

我们之前已经讨论过混合精度技术(这里这里这里),这篇博客文章是对这些技术的总结,如果您是混合精度新手,则可以作为入门介绍。

混合精度训练实践

混合精度训练技术——将较低精度的 float16 或 bfloat16 数据类型与 float32 数据类型一起使用——具有广泛的适用性和有效性。参见图 1,其中展示了使用混合精度成功训练的模型示例,图 2 和图 3 展示了使用 torch.amp 的加速示例。

图 1:使用 float16 成功训练的 DL 工作负载示例 (来源)。

图 2:在 NVIDIA 8xV100 上使用 torch.amp 进行混合精度训练与在 8xV100 GPU 上进行 float32 训练的性能对比。柱状图表示 torch.amp 相对于 float32 的加速因子。(越高越好。)(来源)。

图 3. 在 NVIDIA 8xA100 上使用 torch.amp 进行混合精度训练与 8xV100 GPU 的性能对比。柱状图表示 A100 相对于 V100 的加速因子。(越高越好。)(来源)。

有关更多混合精度工作负载示例,请参阅 NVIDIA 深度学习示例存储库

3D 医学图像分析注视估计视频合成条件 GAN卷积 LSTM 中可以看到类似的性能图表。Huang 等人 表明,在 V100 GPU 上,混合精度训练比 float32 快 1.5 倍到 5.5 倍,在各种网络上,在 A100 GPU 上则快 1.3 倍到 2.5 倍。在非常大的网络上,对混合精度的需求更加明显。Narayanan 等人 报告称,在 1024 个 A100 GPU 上训练 GPT-3 175B(批大小为 1536)需要 34 天,但据估计,使用 float32 则需要一年以上!

开始使用 torch.amp 进行混合精度训练

PyTorch 1.6 中引入的 torch.amp 可以轻松利用 float16 或 bfloat16 数据类型进行混合精度训练。有关更多详细信息,请参阅这篇 博客文章教程文档。图 4 展示了将 AMP 与梯度缩放应用于网络的示例。

import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

for data, label in data_iter:
   optimizer.zero_grad()
   # Casts operations to mixed precision
   with torch.amp.autocast(device_type=“cuda”, dtype=torch.float16):
      loss = model(data)

   # Scales the loss, and calls backward()
   # to create scaled gradients
   scaler.scale(loss).backward()

   # Unscales gradients and calls
   # or skips optimizer.step()
   scaler.step(optimizer)

   # Updates the scale for next iteration
   scaler.update()

图 4:AMP 食谱

选择正确的方法

使用 float16 或 bfloat16 进行开箱即用的混合精度训练可以有效地加速许多深度学习模型的收敛,但某些模型可能需要更仔细的数值精度管理。以下是一些选项

  • 全 float32 精度。默认情况下,PyTorch 中的浮点张量和模块以 float32 精度创建,但这是一种历史遗留产物,不能代表训练大多数现代深度学习网络。网络很少需要如此高的数值精度。
  • 启用 TensorFloat32 (TF32) 模式。在 Ampere 及更高版本的 CUDA 设备上,矩阵乘法和卷积可以使用 TensorFloat32 (TF32) 模式进行更快但精度稍低的计算。有关更多详细信息,请参阅 使用 NVIDIA TF32 Tensor Core 加速 AI 训练 博客文章。默认情况下,PyTorch 为卷积启用 TF32 模式,但不为矩阵乘法启用,除非网络需要全 float32 精度,否则我们建议也为矩阵乘法启用此设置(有关如何操作,请参阅 此处 的文档)。它可以显着加速计算,而通常数值精度损失可忽略不计。
  • 使用 torch.amp 和 bfloat16 或 float16。这两种低精度浮点数据类型通常速度相当,但某些网络可能仅在使用其中一种时才能收敛。如果网络需要更高的精度,则可能需要使用 float16,如果网络需要更大的动态范围,则可能需要使用 bfloat16,其动态范围与 float32 相同。如果观察到溢出,例如,我们建议尝试 bfloat16。

还有比此处介绍的更高级的选项,例如仅对模型的部分使用 torch.amp 的自动转换,或直接管理混合精度。这些主题在很大程度上超出了本博客文章的范围,但请参阅下面的“最佳实践”部分。

最佳实践

我们强烈建议在训练网络时,尽可能使用 torch.amp 或 TF32 模式(在 Ampere 及更高版本的 CUDA 设备上)进行混合精度训练。但是,如果这些方法之一不起作用,我们建议以下操作

  • 高性能计算 (HPC) 应用程序、回归任务和生成网络可能确实需要全 float32 IEEE 精度才能按预期收敛。
  • 尝试选择性地应用 torch.amp。特别是,我们建议首先在执行来自 torch.linalg 模块的操作或进行预处理或后处理的区域禁用它。这些操作通常特别敏感。请注意,TF32 模式是一个全局开关,不能在网络的区域上选择性地使用。首先启用 TF32 以检查网络的运算符是否对该模式敏感,否则禁用它。
  • 如果您在使用 torch.amp 时遇到类型不匹配,我们不建议一开始就插入手动转换。此错误表明网络存在问题,通常值得首先调查。
  • 通过实验确定您的网络是否对格式的范围和/或精度敏感。例如,由于在 bfloat16 中训练的潜在大范围,在 float16 中微调 bfloat16 预训练模型 很容易遇到 float16 中的范围问题,因此如果模型是在 bfloat16 中训练的,用户应坚持使用 bfloat16 微调。
  • 混合精度训练的性能提升可能取决于多种因素(例如,计算密集型与内存密集型问题),用户应使用 调优指南 来消除其训练脚本中的其他瓶颈。尽管 BF16 和 FP16 具有相似的理论性能优势,但在实践中可能具有不同的速度。建议尝试上述格式,并在保持所需数值行为的同时使用速度最佳的格式。

有关更多详细信息,请参阅 AMP 教程使用 Tensor Core 训练神经网络,并参阅 PyTorch Dev Discussion 上的帖子“浮点精度的更深入细节”。

结论

混合精度训练是在现代硬件上训练深度学习模型的必要工具,并且随着较低精度运算和 float32 之间的性能差距在新硬件上持续增长,它将在未来变得更加重要,如图 5 所示。

图 5:Volta 和 Ampere GPU 上 float16 (FP16) 与 float32 矩阵乘法的相对峰值吞吐量。在 Ampere 上,也显示了 TensorFloat32 (TF32) 模式和 bfloat16 矩阵乘法的相对峰值吞吐量。随着新硬件的发布,像 float16 和 bfloat16 这样的低精度数据类型相对于 float32 矩阵乘法的相对峰值吞吐量预计会增长。

PyTorch 的 torch.amp 模块使您可以轻松开始使用混合精度,我们强烈建议使用它来更快地训练并减少内存使用。torch.amp 同时支持 float16 和 bfloat16 混合精度。

仍然有一些网络难以使用混合精度进行训练,对于这些网络,我们建议在 Ampere 及更高版本的 CUDA 硬件上尝试 TF32 加速矩阵乘法。网络很少对精度如此敏感,以至于每个操作都需要全 float32 精度。

如果您对 PyTorch 中 torch.amp 或混合精度支持有任何疑问或建议,请通过在 PyTorch 论坛的混合精度类别 中发帖或 在 PyTorch GitHub 页面上提交 issue 的方式告知我们。