现代神经网络的高效训练通常依赖于低精度数据类型。在 A100 GPU 上,float16 矩阵乘法和卷积的峰值性能比 float32 快 16 倍。由于 float16 和 bfloat16 数据类型的大小仅为 float32 的一半,它们可以将带宽受限算子的性能提高一倍,并减少训练网络所需的内存,从而支持更大的模型、更大的批次或更大的输入。使用像 torch.amp(“自动混合精度”的缩写)这样的模块,可以轻松获得低精度数据类型带来的速度和内存优势,同时保持收敛行为。

追求更快的速度和更少的内存占用始终是有利的——深度学习从业者可以测试更多的模型架构和超参数,并训练更大、更强大的模型。若不使用混合精度,训练像 Narayanan 等人Brown 等人 所述的那种超大规模模型(即使有专家手动优化,也需要数千个 GPU 训练数月)是不可行的。

我们之前讨论过混合精度技术(此处此处此处),本文旨在总结这些技术,并为刚接触混合精度的用户提供入门介绍。

混合精度训练实践

混合精度训练技术——即在 float32 数据类型的同时使用低精度的 float16 或 bfloat16 数据类型——具有广泛的适用性和有效性。参见图 1 了解成功应用混合精度训练的模型示例,图 2 和图 3 展示了使用 torch.amp 带来的加速效果。

图 1:成功使用 float16 训练的深度学习工作负载示例(来源)。

图 2:在 NVIDIA 8xV100 上使用 torch.amp 进行混合精度训练与在 8xV100 GPU 上使用 float32 训练的性能对比。柱状图表示 torch.amp 相对于 float32 的加速比。(数值越高越好。)(来源)。

图 3:在 NVIDIA 8xA100 上使用 torch.amp 进行混合精度训练与在 8xV100 GPU 上使用 float32 训练的性能对比。柱状图表示 A100 相对于 V100 的加速比。(数值越高越好。)(来源)。

有关更多混合精度工作负载示例,请查看 NVIDIA 深度学习示例仓库

类似的性能图表可见于 3D 医学图像分析视线估计视频合成条件 GANs卷积 LSTMsHuang 等人 表明,在 V100 GPU 上,混合精度训练比 float32 快 1.5 到 5.5 倍;在 A100 GPU 上,针对各种网络,速度进一步提升 1.3 到 2.5 倍。在超大型网络中,混合精度的必要性更为显著。Narayanan 等人 报告称,使用 1024 个 A100 GPU(批次大小为 1536)训练 GPT-3 175B 需要 34 天,但如果使用 float32,估计需要一年多!

使用 torch.amp 开始混合精度训练

PyTorch 1.6 中引入的 torch.amp 使得使用 float16 或 bfloat16 数据类型进行混合精度训练变得简单。更多详细信息,请参阅此博客文章教程文档。图 4 展示了将带有梯度缩放(grad scaling)的 AMP 应用于网络的示例。

import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

for data, label in data_iter:
   optimizer.zero_grad()
   # Casts operations to mixed precision
   with torch.amp.autocast(device_type=“cuda”, dtype=torch.float16):
      loss = model(data)

   # Scales the loss, and calls backward()
   # to create scaled gradients
   scaler.scale(loss).backward()

   # Unscales gradients and calls
   # or skips optimizer.step()
   scaler.step(optimizer)

   # Updates the scale for next iteration
   scaler.update()

图 4:AMP 使用配方

选择正确的方法

使用 float16 或 bfloat16 进行开箱即用的混合精度训练,对于加速许多深度学习模型的收敛非常有效,但某些模型可能需要更精细的数值精度管理。以下是一些可选方案:

  • 完整 float32 精度:在 PyTorch 中,浮点张量和模块默认以 float32 精度创建,但这属于历史遗留问题,并不代表大多数现代深度学习网络的训练需求。网络很少需要如此高的数值精度。
  • 启用 TensorFloat32 (TF32) 模式:在 Ampere 及更新的 CUDA 设备上,矩阵乘法和卷积可以使用 TensorFloat32 (TF32) 模式,以实现更快但精度略微降低的计算。更多详情请参阅 使用 NVIDIA TF32 Tensor Cores 加速 AI 训练 博客文章。PyTorch 默认对卷积启用 TF32 模式,但不对矩阵乘法启用。除非网络需要完整的 float32 精度,否则我们建议对矩阵乘法也启用此设置(参见此处的文档了解操作方法)。它可以在数值精度损失几乎可忽略的情况下显著加速计算。
  • 使用 torch.amp 配合 bfloat16 或 float16:这两种低精度浮点数据类型通常具有相当的速度,但某些网络可能只能在其中一种上收敛。如果网络需要更高精度,可能需要使用 float16;如果网络需要更大的动态范围,则可能需要使用 bfloat16(其动态范围与 float32 相当)。例如,如果观察到溢出(overflows),建议尝试 bfloat16。

除了上述方法外,还有更高级的选项,例如仅对模型的部分区域使用 torch.amp 的自动转换(autocasting),或直接管理混合精度。这些主题大多超出了本博客文章的范围,但请参考下方的“最佳实践”部分。

最佳实践

我们强烈建议在训练网络时,尽可能使用 torch.amp 或 TF32 模式(在 Ampere 及更新的 CUDA 设备上)。如果这些方法都不奏效,建议采取以下措施:

  • 高性能计算 (HPC) 应用、回归任务和生成式网络可能确实需要完整的 float32 IEEE 精度才能达到预期的收敛效果。
  • 尝试有选择地应用 torch.amp。我们特别建议首先在执行 torch.linalg 模块算子的区域或进行预处理/后处理时禁用它。这些操作通常特别敏感。请注意,TF32 模式是一个全局开关,不能在网络的特定区域有选择地使用。应先启用 TF32 以检查网络的算子是否对该模式敏感,若敏感则禁用。
  • 如果您在使用 torch.amp 时遇到类型不匹配(type mismatches),我们不建议直接手动插入强制转换。此错误表明网络可能存在某些问题,通常值得优先调查。
  • 通过实验确定您的网络对格式的范围和/或精度是否敏感。例如,使用 float16 微调 bfloat16 预训练模型 很容易在 float16 中出现范围问题,因为 bfloat16 训练的范围可能很大。如果模型是在 bfloat16 下训练的,用户应坚持使用 bfloat16 进行微调。
  • 混合精度训练的性能增益取决于多种因素(例如计算密集型与内存密集型问题),用户应使用 调优指南 来消除训练脚本中的其他瓶颈。尽管理论性能优势相似,但 BF16 和 FP16 在实践中可能表现出不同的速度。建议尝试上述格式,并在保持理想数值行为的前提下,选择速度最快的一种。

有关更多详细信息,请参阅 AMP 教程使用 Tensor Cores 训练神经网络,并查看 PyTorch 开发讨论区上的“浮点精度的深度细节”一文。

结论

混合精度训练是在现代硬件上训练深度学习模型的重要工具。随着图 5 所显示的硬件演进,低精度算子与 float32 之间的性能差距不断拉大,混合精度在未来将变得愈发重要。

图 5:在 Volta 和 Ampere GPU 上,float16 (FP16) 与 float32 矩阵乘法的相对峰值吞吐量。Ampere 上也显示了 TensorFloat32 (TF32) 模式和 bfloat16 矩阵乘法的相对峰值吞吐量。预计随着新硬件的发布,float16 和 bfloat16 等低精度数据类型与 float32 矩阵乘法的相对峰值吞吐量差距将进一步扩大。

PyTorch 的 torch.amp 模块使混合精度入门变得简单,我们强烈建议使用它来加速训练并减少内存占用。torch.amp 同时支持 float16 和 bfloat16 混合精度。

仍有一些网络难以通过混合精度进行训练,对于这些网络,我们建议在 Ampere 及更新的 CUDA 硬件上尝试 TF32 加速的矩阵乘法。很少有网络会对精度敏感到必须对每个算子都使用完整的 float32 精度。

如果您对 PyTorch 中的 torch.amp 或混合精度支持有任何疑问或建议,请通过 PyTorch 论坛的混合精度板块 发帖,或在 PyTorch GitHub 页面上提交 issue 告知我们。