现代神经网络的高效训练通常依赖于使用低精度数据类型。在 A100 GPU 上,float16 矩阵乘法和卷积的峰值性能比 float32 峰值性能快 16 倍。由于 float16 和 bfloat16 数据类型的大小仅为 float32 的一半,它们可以使带宽受限的核函数性能加倍,并减少训练网络所需的内存,从而支持更大的模型、更大的批次或更大的输入。使用像 torch.amp 这样的模块(简称“自动混合精度”)可以轻松获得低精度数据类型的速度和内存使用优势,同时保持收敛行为。
更快并使用更少内存总是更有优势——深度学习从业者可以测试更多的模型架构和超参数,并且可以训练更大、更强大的模型。不使用混合精度,训练像 Narayanan et al. 和 Brown et al. 所述的超大型模型(即使有专家手工优化,也需要数千个 GPU 耗时数月训练)是不可行的。
我们之前讨论过混合精度技术(此处、此处和此处),本博文是对这些技术的总结,如果您是混合精度新手,也是一个入门介绍。
混合精度训练实践
混合精度训练技术——将低精度 float16 或 bfloat16 数据类型与 float32 数据类型并行使用——是广泛适用且有效的。图 1 展示了成功使用混合精度训练的部分模型示例,图 2 和图 3 展示了使用 torch.amp 的加速示例。
图 1:使用 float16 成功训练的深度学习工作负载抽样(来源)。
图 2:使用 torch.amp 在 NVIDIA 8xV100 上进行混合精度训练的性能对比使用 8xV100 GPU 进行 float32 训练。柱状图表示 torch.amp 相较于 float32 的加速因子。(越高越好。)(来源)。
图 3:使用 torch.amp 在 NVIDIA 8xA100 上的混合精度训练性能对比 8xV100 GPU。柱状图表示 A100 相较于 V100 的加速因子。(越高越好。)(来源)。
有关更多混合精度工作负载示例,请参阅 NVIDIA 深度学习示例仓库。
在 3D 医学图像分析、注视估计、视频合成、条件 GAN 和 卷积 LSTM 中可以看到类似的性能图表。Huang et al. 的研究表明,在 V100 GPU 上,混合精度训练比 float32 快 1.5 到 5.5 倍,在 A100 GPU 上额外快 1.3 到 2.5 倍,适用于各种网络。在超大型网络上,对混合精度的需求更加明显。Narayanan et al. 报告称,在 1024 个 A100 GPU 上训练 GPT-3 175B(批次大小为 1536)需要 34 天,但估计使用 float32 将需要一年多的时间!
使用 torch.amp 开始混合精度训练
torch.amp 在 PyTorch 1.6 中引入,可以轻松利用 float16 或 bfloat16 数据类型进行混合精度训练。有关更多详细信息,请参阅这篇博文、此教程和文档。图 4 展示了将 AMP 与梯度缩放应用于网络的示例。
import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()
for data, label in data_iter:
optimizer.zero_grad()
# Casts operations to mixed precision
with torch.amp.autocast(device_type=“cuda”, dtype=torch.float16):
loss = model(data)
# Scales the loss, and calls backward()
# to create scaled gradients
scaler.scale(loss).backward()
# Unscales gradients and calls
# or skips optimizer.step()
scaler.step(optimizer)
# Updates the scale for next iteration
scaler.update()
图 4:AMP 代码示例
选择正确的方法
开箱即用的 float16 或 bfloat16 混合精度训练对于加速许多深度学习模型的收敛是有效的,但有些模型可能需要更精细的数值精度管理。以下是一些选项:
- 完整 float32 精度。PyTorch 中默认以 float32 精度创建浮点张量和模块,但这只是一个历史遗留,不代表大多数现代深度学习网络的训练方式。网络很少需要这么高的数值精度。
- 启用 TensorFloat32 (TF32) 模式。在 Ampere 及更高版本的 CUDA 设备上,矩阵乘法和卷积可以使用 TensorFloat32 (TF32) 模式,计算更快但精度略低。有关更多详细信息,请参阅使用 NVIDIA TF32 Tensor Core 加速 AI 训练这篇博文。默认情况下,PyTorch 为卷积启用 TF32 模式,但不为矩阵乘法启用,除非网络需要完整 float32 精度,否则我们建议也为矩阵乘法启用此设置(请参阅此处的文档了解如何操作)。它可以在数值精度损失通常可忽略不计的情况下显著加速计算。
- 使用 torch.amp 结合 bfloat16 或 float16。这两种低精度浮点数据类型通常速度相当,但有些网络可能只能使用其中一种收敛。如果网络需要更高精度,可能需要使用 float16;如果网络需要更大的动态范围,可能需要使用 bfloat16,其动态范围等于 float32。例如,如果观察到溢出,我们建议尝试 bfloat16。
还有比此处介绍的更高级的选项,例如仅对模型的某些部分使用 torch.amp 的自动类型转换,或直接管理混合精度。这些主题在很大程度上超出了本博文的范围,但请参阅下方的“最佳实践”部分。
最佳实践
我们强烈建议在训练网络时,只要可能就使用 torch.amp 或 TF32 模式(在 Ampere 及更高版本的 CUDA 设备上)进行混合精度训练。但是,如果其中一种方法不起作用,我们建议以下方法:
- 高性能计算 (HPC) 应用、回归任务和生成网络可能需要完整 float32 IEEE 精度才能按预期收敛。
- 尝试选择性应用 torch.amp。特别地,我们建议首先在执行 torch.linalg 模块操作的区域或进行预处理或后处理时禁用它。这些操作通常特别敏感。请注意,TF32 模式是全局开关,不能在网络的特定区域选择性使用。请先启用 TF32 检查网络的运算符是否对该模式敏感,否则禁用它。
- 如果您在使用 torch.amp 时遇到类型不匹配,我们不建议一开始就插入手动类型转换。此错误表明网络存在异常,通常值得先调查。
- 通过实验确定您的网络是否对格式的范围和/或精度敏感。例如,在 float16 中微调 bfloat16 预训练模型很容易在 float16 中遇到范围问题,因为在 bfloat16 中训练可能产生较大范围,因此如果模型是在 bfloat16 中训练的,用户应坚持使用 bfloat16 微调。
- 混合精度训练的性能提升取决于多种因素(例如计算密集型问题与内存密集型问题),用户应使用调优指南来消除训练脚本中的其他瓶颈。尽管 BF16 和 FP16 具有类似的理论性能优势,但在实践中速度可能不同。建议尝试提到的格式,并在保持所需数值行为的同时使用速度最佳的格式。
更多详细信息,请参阅AMP 教程、使用 Tensor Core 训练神经网络,并查看 PyTorch 开发者讨论上的博文“浮点精度更深入的细节”。
结论
混合精度训练是在现代硬件上训练深度学习模型的基本工具,并且随着新硬件上低精度操作与 float32 之间的性能差距不断扩大(体现在图 5 中),未来会变得更加重要。
图 5:Volta 和 Ampere GPU 上 float16 (FP16) 相较于 float32 矩阵乘法的相对峰值吞吐量。在 Ampere 上,还展示了 TensorFloat32 (TF32) 模式和 bfloat16 矩阵乘法的相对峰值吞吐量。随着新硬件的发布,float16 和 bfloat16 等低精度数据类型相较于 float32 矩阵乘法的相对峰值吞吐量预计会增长。
PyTorch 的 torch.amp 模块使得开始使用混合精度变得容易,我们强烈建议使用它来更快训练并减少内存使用。torch.amp 支持 float16 和 bfloat16 混合精度。
仍有一些网络使用混合精度训练起来比较棘手,对于这些网络,我们建议在 Ampere 及更高版本的 CUDA 硬件上尝试 TF32 加速的矩阵乘法。网络很少对精度如此敏感,以至于每个操作都需要完整 float32 精度。
如果您对 PyTorch 中的 torch.amp 或混合精度支持有疑问或建议,请通过在PyTorch 论坛的混合精度类别发帖或在PyTorch GitHub 页面上提交问题告知我们。