跳转到主要内容

现代神经网络的有效训练通常依赖于使用较低精度的数据类型。在 A100 GPU 上,float16 矩阵乘法和卷积的峰值性能比 float32 峰值性能快 16 倍。由于 float16 和 bfloat16 数据类型的大小只有 float32 的一半,它们可以使带宽受限内核的性能翻倍,并减少训练网络所需的内存,从而允许使用更大的模型、更大的批次或更大的输入。使用像 torch.amp(“自动混合精度”的缩写)这样的模块,可以轻松获得较低精度数据类型的速度和内存使用优势,同时保持收敛行为。

更快地运行和使用更少的内存总是有利的——深度学习从业者可以测试更多的模型架构和超参数,并且可以训练更大、更强大的模型。训练像 Narayanan 等人Brown 等人 描述的那样非常大的模型(即使有专家手写优化,也需要数千个 GPU 数月才能训练完成),如果没有使用混合精度,这是不可行的。

我们之前已经讨论过混合精度技术(这里这里这里),这篇博客文章是对这些技术的总结,也是对混合精度新手的介绍。

混合精度训练实践

混合精度训练技术——使用较低精度的 float16 或 bfloat16 数据类型以及 float32 数据类型——具有广泛的适用性和有效性。参见图 1 了解成功使用混合精度训练的模型示例,参见图 2 和图 3 了解使用 torch.amp 的加速示例。

图 1:成功使用 float16 训练的 DL 工作负载样本(来源)。

图 2:在 NVIDIA 8xV100 上使用 torch.amp 进行混合精度训练与在 8xV100 GPU 上进行 float32 训练的性能比较。条形图表示 torch.amp 相对于 float32 的加速因子。(越高越好。)(来源)。

图 3:在 NVIDIA 8xA100 上使用 torch.amp 进行混合精度训练与 8xV100 GPU 的性能比较。条形图表示 A100 相对于 V100 的加速因子。(越高越好。)(来源)。

有关更多混合精度工作负载示例,请参见 NVIDIA 深度学习示例仓库

类似的性能图表可以在 3D 医学图像分析注视估计视频合成条件 GANs卷积 LSTMs 中看到。Huang 等人 的研究表明,在 V100 GPU 上,混合精度训练比 float32 快 1.5 到 5.5 倍,在 A100 GPU 上,在各种网络上额外快 1.3 到 2.5 倍。在非常大的网络上,对混合精度的需求更加明显。Narayanan 等人 报告称,在 1024 个 A100 GPU 上训练 GPT-3 175B(批大小为 1536)需要 34 天,但估计使用 float32 将需要一年多!

使用 torch.amp 开始混合精度训练

PyTorch 1.6 中引入的 torch.amp 使得利用 float16 或 bfloat16 dtypes 进行混合精度训练变得容易。有关更多详细信息,请参见这篇博客文章教程文档。图 4 显示了一个将 AMP 与梯度缩放应用于网络的示例。

import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

for data, label in data_iter:
   optimizer.zero_grad()
   # Casts operations to mixed precision
   with torch.amp.autocast(device_type=“cuda”, dtype=torch.float16):
      loss = model(data)

   # Scales the loss, and calls backward()
   # to create scaled gradients
   scaler.scale(loss).backward()

   # Unscales gradients and calls
   # or skips optimizer.step()
   scaler.step(optimizer)

   # Updates the scale for next iteration
   scaler.update()

图 4:AMP 教程

选择正确的方法

开箱即用的 float16 或 bfloat16 混合精度训练对于加速许多深度学习模型的收敛是有效的,但有些模型可能需要更仔细的数值精度管理。以下是一些选项:

  • 完整的 float32 精度。在 PyTorch 中,浮点张量和模块默认以 float32 精度创建,但这只是一个历史遗留问题,不能代表大多数现代深度学习网络的训练情况。网络很少需要如此高的数值精度。
  • 启用 TensorFloat32 (TF32) 模式。在 Ampere 及更高版本的 CUDA 设备上,矩阵乘法和卷积可以使用 TensorFloat32 (TF32) 模式进行更快但精度稍低的计算。有关更多详细信息,请参见 使用 NVIDIA TF32 Tensor Cores 加速 AI 训练 博客文章。PyTorch 默认对卷积启用 TF32 模式,但不对矩阵乘法启用,除非网络需要完整的 float32 精度,否则我们建议也为矩阵乘法启用此设置(有关如何操作,请参见此处的文档)。它可以显著加速计算,通常数值精度损失可以忽略不计。
  • 使用 torch.amp 与 bfloat16 或 float16。这两种低精度浮点数据类型通常速度相当,但有些网络可能只用其中一种才能收敛。如果网络需要更高的精度,可能需要使用 float16;如果网络需要更大的动态范围,可能需要使用 bfloat16,其动态范围与 float32 相同。例如,如果观察到溢出,我们建议尝试 bfloat16。

除了这里介绍的选项之外,还有更高级的选项,例如仅对模型的某些部分使用 torch.amp 的自动类型转换,或者直接管理混合精度。这些主题大多超出了本博客文章的范围,但请参阅下面的“最佳实践”部分。

最佳实践

我们强烈建议在训练网络时尽可能使用 torch.amp 或 TF32 模式(在 Ampere 及更高版本的 CUDA 设备上)进行混合精度训练。但是,如果其中一种方法不起作用,我们建议以下方法:

  • 高性能计算 (HPC) 应用程序、回归任务和生成网络可能需要完整的 float32 IEEE 精度才能按预期收敛。
  • 尝试选择性地应用 torch.amp。特别是,我们建议首先在执行 `torch.linalg` 模块操作或进行预处理/后处理的区域禁用它。这些操作通常特别敏感。请注意,TF32 模式是一个全局开关,不能选择性地应用于网络的某些区域。首先启用 TF32 检查网络的运算符是否对该模式敏感,否则禁用它。
  • 如果在 torch.amp 中遇到类型不匹配,我们不建议一开始就插入手动转换。这个错误表明网络可能存在问题,通常值得首先进行调查。
  • 通过实验找出您的网络是否对格式的范围和/或精度敏感。例如,在 float16 中微调 bfloat16 预训练模型 很容易遇到 float16 中的范围问题,因为 bfloat16 训练可能具有较大的范围,因此如果模型是用 bfloat16 训练的,用户应坚持使用 bfloat16 微调。
  • 混合精度训练的性能增益可能取决于多种因素(例如计算密集型与内存密集型问题),用户应使用 调优指南 来消除训练脚本中的其他瓶颈。尽管 BF16 和 FP16 具有相似的理论性能优势,但在实践中它们可能具有不同的速度。建议尝试提及的格式,并在保持所需数值行为的同时使用速度最佳的格式。

有关更多详细信息,请参阅 AMP 教程使用 Tensor Cores 训练神经网络,并查看 PyTorch Dev Discussion 上的文章“浮点精度更深入的详细信息”。

结论

混合精度训练是现代硬件上训练深度学习模型的重要工具,随着较低精度操作与 float32 之间的性能差距在新硬件上持续扩大(如图 5 所示),它在未来将变得更加重要。

图 5:Volta 和 Ampere GPU 上 float16 (FP16) 与 float32 矩阵乘法的相对峰值吞吐量。在 Ampere 上,还显示了 TensorFloat32 (TF32) 模式和 bfloat16 矩阵乘法的相对峰值吞吐量。随着新硬件的发布,float16 和 bfloat16 等低精度数据类型与 float32 矩阵乘法的相对峰值吞吐量预计将增长。

PyTorch 的 torch.amp 模块使混合精度训练变得容易上手,我们强烈建议使用它来更快地训练并减少内存使用。torch.amp 支持 float16 和 bfloat16 混合精度。

仍然有一些网络难以使用混合精度进行训练,对于这些网络,我们建议在 Ampere 及更高版本的 CUDA 硬件上尝试 TF32 加速的矩阵乘法。网络很少会如此对精度敏感,以至于每个操作都需要完整的 float32 精度。

如果您对 PyTorch 中的 torch.amp 或混合精度支持有疑问或建议,请通过在 PyTorch 论坛的混合精度类别 中发帖或在 PyTorch GitHub 页面上提交问题 告知我们。