您使用随机梯度下降 (SGD) 还是 Adam?无论您使用哪种程序来训练您的神经网络,您都可以使用 PyTorch 1.6 现在原生支持的一种简单新技术——随机权重平均 (SWA) [1],以几乎没有额外成本的方式实现显著更好的泛化能力。即使您已经训练了您的模型,也可以通过从预训练模型开始,运行 SWA 少量 epoch 来轻松实现 SWA 的优势。研究人员一次又一次地发现,SWA 以很小的成本或努力就能提高经过良好调整的模型在各种实际应用中的性能!
SWA 具有广泛的应用和特性
- 与计算机视觉中的标准训练技术相比(例如,ImageNet 和 CIFAR 基准上的 VGG、ResNets、Wide ResNets 和 DenseNets [1, 2]),SWA 显著提高了性能。
- SWA 在半监督学习和域适应的关键基准上提供了最先进的性能 [2]。
- SWA 被证明可以改善语言建模(例如,WikiText-2 上的 AWD-LSTM [4])和深度强化学习中的策略梯度方法 [3] 的性能。
- SWAG 是 SWA 的扩展,可以近似贝叶斯深度学习中的贝叶斯模型平均,并在各种设置中实现最先进的不确定性校准结果。此外,其最近的泛化 MultiSWAG 提供了显著的额外性能提升并缓解了双重下降 [4, 10]。另一种方法,子空间推理,近似 SWA 解周围参数空间中一个小子空间中的贝叶斯后验 [5]。
- 用于低精度训练的 SWA(SWALP)可以匹配全精度 SGD 训练的性能,即使所有数字都量化到 8 位,包括梯度累加器 [6]。
- 并行 SWA(SWAP)被证明通过使用大批量大小极大地加速了神经网络的训练,特别是通过在 27 秒内将神经网络训练到 CIFAR-10 上 94% 的准确率,创造了记录 [11]。

图 1。SWA 和 SGD 在 CIFAR-100 上使用预激活 ResNet-164 的示意图 [1]。左图:三个 FGE 样本和相应的 SWA 解(在权重空间中平均)的测试误差曲面。中图和右图:测试误差和训练损失曲面,显示从 SGD 经过 125 个训练 epoch 后相同的初始化开始,SGD(收敛时)和 SWA 提出的权重。请参阅 [1] 了解这些图的构建细节。
简而言之,SWA 对 SGD(或任何随机优化器)遍历的权重执行相等平均,并采用修改后的学习率调度(参见图 1 的左面板)。SWA 解决方案最终位于损失的宽平坦区域的中心,而 SGD 倾向于收敛到低损失区域的边界,使其容易受到训练和测试误差曲面之间偏移的影响(参见图 1 的中面板和右面板)。我们强调,SWA 可以与任何优化器(例如 Adam)一起使用,并且不限于 SGD。
以前,SWA 在 PyTorch contrib 中。在 PyTorch 1.6 中,我们在 torch.optim.swa_utils 中提供了 SWA 的新便捷实现。
这只是平均 SGD 吗?
从高层次来看,平均 SGD 迭代可以追溯到凸优化中的几十年前 [7, 8],有时被称为 Polyak-Ruppert 平均,或平均 SGD。但细节很重要。平均 SGD 通常与衰减学习率和指数移动平均 (EMA) 结合使用,通常用于凸优化。在凸优化中,重点在于改进收敛速度。在深度学习中,这种形式的平均 SGD 平滑了 SGD 迭代的轨迹,但表现并没有太大差异。
相比之下,SWA 使用 SGD 迭代的相等平均,并采用修改后的循环或高常数学习率,并利用深度学习中特有的训练目标的平坦性 [8] 来提高泛化能力。
随机权重平均是如何工作的?
SWA 成功运行有两个重要因素。首先,SWA 使用修改后的学习率调度,使得 SGD(或 Adam 等其他优化器)继续在最优解附近“跳动”并探索不同的模型,而不是简单地收敛到单一解决方案。例如,我们可以在训练时间的前 75% 使用标准的衰减学习率策略,然后将学习率设置为一个合理的常数高值,用于剩余的 25% 时间(参见下面的图 2)。第二个因素是取 SGD 遍历的网络权重的平均值(通常是相等平均)。例如,我们可以在训练时间的最后 25% 的每个 epoch 结束时,保持权重的运行平均值(参见图 2)。训练完成后,我们将网络的权重设置为计算出的 SWA 平均值。

图 2。SWA 采用的学习率调度示意图。前 75% 的训练时间使用标准衰减调度,然后剩余 25% 的时间使用高常数值。SWA 平均值在训练的最后 25% 期间形成。
一个重要的细节是批归一化。批归一化层在训练期间计算激活的运行统计信息。请注意,SWA 权重的平均值在训练期间从未用于进行预测。因此,批归一化层在训练结束时没有计算激活统计信息。我们可以通过使用 SWA 模型对训练数据进行一次前向传播来计算这些统计信息。
虽然我们为了描述简单起见专注于 SGD,但 SWA 可以与任何优化器结合使用。您也可以使用循环学习率而不是高常数值(例如,参见 [2])。
如何在 PyTorch 中使用 SWA?
在 `torch.optim.swa_utils` 中,我们实现了所有 SWA 组件,以便方便地将 SWA 与任何模型一起使用。特别是,我们实现了用于 SWA 模型的 `AveragedModel` 类、`SWALR` 学习率调度器和 `update_bn` 实用函数,用于在训练结束时更新 SWA 批归一化统计信息。
在下面的示例中,`swa_model` 是累积权重平均值的 SWA 模型。我们总共训练模型 300 个 epoch,并在第 160 个 epoch 切换到 SWA 学习率调度并开始收集参数的 SWA 平均值。
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
for epoch in range(100):
for input, target in loader:
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
if epoch > swa_start:
swa_model.update_parameters(model)
swa_scheduler.step()
else:
scheduler.step()
# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
接下来,我们将详细解释 `torch.optim.swa_utils` 的每个组件。
`AveragedModel` 类用于计算 SWA 模型的权重。您可以通过运行 `swa_model = AveragedModel(model)` 来创建一个平均模型。然后,您可以通过 `swa_model.update_parameters(model)` 更新平均模型的参数。默认情况下,`AveragedModel` 计算您提供的参数的运行相等平均值,但您也可以使用 `avg_fn` 参数使用自定义平均函数。在以下示例中,`ema_model` 计算指数移动平均值。
ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
0.1 * averaged_model_parameter + 0.9 * model_parameter
ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
在实践中,我们发现图 2 中采用修改后的学习率调度并进行相等平均可提供最佳性能。
`SWALR` 是一个学习率调度器,它将学习率退火到一个固定值,然后保持不变。例如,以下代码创建了一个调度器,它在线性退火学习率,从其初始值在每个参数组中经过 `5` 个 epoch 到 `0.05`。
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
我们还实现了余弦退火到固定值(`anneal_strategy="cos"`)。在实践中,我们通常在 epoch `swa_start` (例如,在训练 epoch 的 75% 之后)切换到 `SWALR`,并同时开始计算权重的运行平均值
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
swa_start = 75
for epoch in range(100):
# <train epoch>
if i > swa_start:
swa_model.update_parameters(model)
swa_scheduler.step()
else:
scheduler.step()
最后,`update_bn` 是一个实用函数,用于在给定的数据加载器 `loader` 上计算 SWA 模型的批归一化统计信息
torch.optim.swa_utils.update_bn(loader, swa_model)
`update_bn` 将 `swa_model` 应用于数据加载器中的每个元素,并计算模型中每个批归一化层的激活统计信息。
一旦您计算了 SWA 平均值并更新了批归一化层,您就可以应用 `swa_model` 来对测试数据进行预测。
它为什么有效?
损失曲面存在大片平坦区域 [9]。在下面的图 3 中,我们展示了参数空间的一个子空间中损失曲面的可视化,该子空间包含连接两个独立训练的 SGD 解的路径,使得沿该路径的每个点的损失都相似地低。SGD 收敛到这些区域的边界附近,因为几乎没有梯度信号可以移动到内部,因为该区域中的所有点的损失值都相似地低。通过增加学习率,SWA 在这个平坦区域周围“旋转”,然后通过平均迭代,向平坦区域的中心移动。

图 3:ResNet-20 在 CIFAR-10 数据集上无跳过连接的模式连接可视化。该可视化是与 Javier Ideami (https://losslandscape.com/) 合作创建的。有关更多详细信息,请参阅此博文。
我们期望位于损失的平坦区域中心的解比靠近边界的解具有更好的泛化能力。实际上,训练误差和测试误差曲面在权重空间中并非完全对齐。位于平坦区域中心的解对训练误差和测试误差曲面之间的偏移不那么敏感,而靠近边界的解则更易受影响。在下面的图 4 中,我们展示了连接 SWA 和 SGD 解的方向上的训练损失和测试误差曲面。如您所见,虽然 SWA 解的训练损失高于 SGD 解,但它位于低损失区域的中心,并且具有明显更好的测试误差。

图 4。连接 SWA 解(圆圈)和 SGD 解(方块)的直线上的训练损失和测试误差。SWA 解位于低训练损失的宽阔区域中心,而 SGD 解位于边界附近。由于训练损失和测试误差曲面之间的偏移,SWA 解带来了更好的泛化能力。
SWA 取得了哪些成果?
我们发布了一个 GitHub 仓库,其中包含使用 PyTorch 实现 SWA 训练 DNN 的示例。例如,这些示例可用于在 CIFAR-100 上取得以下结果
VGG-16 | ResNet-164 | WideResNet-28×10 | |
---|---|---|---|
SGD | 72.8 ± 0.3 | 78.4 ± 0.3 | 81.0 ± 0.3 |
SWA | 74.4 ± 0.3 | 79.8 ± 0.4 | 82.5 ± 0.2 |
半监督学习
在后续的论文中,SWA 被应用于半监督学习,并在多种设置中改进了最佳报告结果 [2]。例如,如果只有 4k 训练数据点的训练标签,使用 SWA 可以在 CIFAR-10 上获得 95% 的准确率(此问题之前的最佳报告结果是 93.7%)。这篇论文还探讨了在 epoch 内多次平均,这可以加速收敛并在给定时间内找到更平坦的解。

图 5。快速 SWA 在 CIFAR-10 半监督学习中的性能。快速 SWA 在所有考虑的设置中都取得了创纪录的结果。
强化学习
在另一篇后续论文中,SWA 被证明可以提高策略梯度方法 A2C 和 DDPG 在多款 Atari 游戏和 MuJoCo 环境中的性能 [3]。此应用也是 SWA 与 Adam 结合使用的实例。请记住,SWA 不仅限于 SGD,并且基本上可以使任何优化器受益。
环境名称 | A2C | A2C + SWA |
---|---|---|
突破 | 522 ± 34 | 703 ± 60 |
Qbert | 18777 ± 778 | 21272 ± 655 |
太空入侵者 | 7727 ± 1121 | 21676 ± 8897 |
海中救援 | 1779 ± 4 | 1795 ± 4 |
光束骑士 | 9999 ± 402 | 11321 ± 1065 |
疯狂攀爬者 | 147030 ± 10239 | 139752 ± 11618 |
低精度训练
我们可以通过将向下舍入的权重与向上舍入的权重相结合来过滤量化噪声。此外,通过平均权重以找到损失曲面的平坦区域,权重的大的扰动不会影响解决方案的质量(图 9 和图 10)。最近的工作表明,通过将 SWA 适应低精度设置(一种称为 SWALP 的方法),即使所有训练都在 8 位中进行,也可以匹配全精度 SGD 的性能 [5]。这是一个非常实际且重要的结果,因为 (1) 8 位 SGD 训练的性能明显不如全精度 SGD,并且 (2) 低精度训练比训练后低精度预测(通常设置)困难得多。例如,在 CIFAR-100 上使用浮点(16 位)SGD 训练的 ResNet-164 达到 22.2% 的误差,而 8 位 SGD 达到 24.0% 的误差。相比之下,使用 8 位训练的 SWALP 达到 21.8% 的误差。

图 9。量化解决方案会导致权重扰动,对尖锐解决方案(左)的质量影响大于对宽泛解决方案(右)的质量影响。

图 10。标准低精度训练与 SWALP 的区别。
另一项工作 SQWA 提出了一种在低精度下量化和微调神经网络的方法 [12]。特别是,SQWA 在 CIFAR-100 和 ImageNet 上实现了量化到 2 位的 DNN 的最先进结果。
校准与不确定性估计
通过在损失中找到一个中心解,SWA 还可以改进校准和不确定性表示。事实上,SWA 可以看作是集成的近似,类似于贝叶斯模型平均,但只有一个模型 [1]。
SWA 可以看作是在修改后的学习率调度下取 SGD 迭代的一阶矩。我们可以通过同时取迭代的二阶矩来直接推广 SWA,以形成权重上的高斯近似后验,进一步利用 SGD 迭代来刻画损失几何。这种方法,SWA-Gaussian (SWAG) 是一种简单、可扩展且方便的贝叶斯深度学习中不确定性估计和校准方法 [4]。SWAG 分布近似于真实后验的形状:下面的图 6 展示了 ResNet-20 在 CIFAR-10 上的 SWAG 分布和后验对数密度。

图 6。SWAG 后验近似和 ResNet-20 在 CIFAR-10 上无跳过连接的损失曲面,该损失曲面位于由 SWAG 协方差矩阵的两个最大特征值形成的子空间中。SWAG 分布的形状与后验一致:两个分布的峰值重合,并且两个分布在一个方向上都比正交方向更宽。可视化由 Javier Ideami 合作创建。
从经验来看,SWAG 在不确定性量化、分布外检测、校准和计算机视觉任务中的迁移学习方面,表现与包括 MC dropout、KFAC Laplace 和温度缩放等流行替代方案相当或更好。SWAG 的代码可在此处获取。

图 7。MultiSWAG 概括了 SWAG 和深度集成,对多个吸引盆进行贝叶斯模型平均,从而显著提高了性能。相比之下,如图所示,深度集成选择了不同的模式,而标准变分推断 (VI) 在单个盆中进行边缘化(模型平均)。
MultiSWAG [9] 使用多个独立的 SWAG 模型来形成高斯混合作为近似后验分布。不同的吸引盆包含对数据高度互补的解释。因此,对这些多个盆进行边缘化可以显著提高准确性和不确定性表示。MultiSWAG 可以看作是深度集成的泛化,但性能有所提升。
事实上,我们在图 8 中看到,MultiSWAG 完全缓解了双重下降——更灵活的模型具有单调改进的性能——并且比 SGD 提供了显著改进的泛化能力。例如,当 ResNet-18 的层宽度为 20 时,Multi-SWAG 实现了低于 30% 的误差,而 SGD 实现了超过 45% 的误差,差距超过 15%!

图 8。ResNet-18 在 CIFAR-100 上不同宽度下的 SGD、SWAG 和 Multi-SWAG。我们看到 Multi-SWAG 特别缓解了双重下降,并显著提高了 SGD 的准确性。
参考文献 [10] 还考虑了 Multi-SWA,它使用多个独立训练的 SWA 解决方案进行集成,在不增加任何额外计算成本的情况下,提供了比深度集成更好的性能。MultiSWA 和 MultiSWAG 的代码可在此处获取。
另一种方法,子空间推理,围绕 SWA 解构建一个低维子空间,并在该子空间中对权重进行边缘化以近似贝叶斯模型平均 [5]。子空间推理利用 SGD 迭代的统计信息来构建 SWA 解和子空间。该方法在分类和回归问题中均在预测准确性和不确定性校准方面取得了优异的性能。代码可在此处获取。
试一试!
深度学习中最大的开放问题之一是,为什么 SGD 能够找到好的解决方案,因为训练目标是高度多模态的,并且存在许多参数设置,尽管训练损失为零但泛化能力很差。通过理解与泛化相关的几何特征,例如平坦度,我们可以开始解决这些问题并构建提供更好泛化能力和许多其他有用功能(例如不确定性表示)的优化器。我们介绍了 SWA,一个简单的替代标准优化器(如 SGD 和 Adam)的即插即用方法,原则上可以使任何训练深度神经网络的人受益。SWA 已在多个领域表现出强大的性能,包括计算机视觉、半监督学习、强化学习、不确定性表示、校准、贝叶斯模型平均和低精度训练。
我们鼓励您尝试使用 SWA!现在,SWA 的使用与 PyTorch 中的任何标准训练一样简单。即使您已经训练了模型,您也可以通过从预训练模型开始,运行少量 epoch 的 SWA 来显著提高性能。
[1] 权重平均带来更宽的最优值和更好的泛化能力;Pavel Izmailov, Dmitry Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson;人工智能不确定性 (UAI),2018 年。
[2] 无标签数据存在许多一致的解释:为何要平均;Ben Athiwaratkun, Marc Finzi, Pavel Izmailov, Andrew Gordon Wilson;国际学习表示大会 (ICLR),2019 年。
[3] 通过权重平均提高深度强化学习的稳定性;Evgenii Nikishin, Pavel Izmailov, Ben Athiwaratkun, Dmitrii Podoprikhin, Timur Garipov, Pavel Shvechikov, Dmitry Vetrov, Andrew Gordon Wilson;UAI 2018 研讨会:深度学习中的不确定性,2018 年。
[4] 贝叶斯深度学习中不确定性的简单基线 Wesley Maddox, Timur Garipov, Pavel Izmailov, Andrew Gordon Wilson;神经网络信息处理系统 (NeurIPS),2019 年。
[5] 贝叶斯深度学习的子空间推断 Pavel Izmailov, Wesley Maddox, Polina Kirichenko, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson 人工智能不确定性 (UAI), 2019。
[6] SWALP:低精度训练中的随机权重平均 Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa;国际机器学习会议 (ICML),2019 年。
[7] David Ruppert. 从缓慢收敛的 Robbins-Monro 过程进行高效估计;技术报告,康奈尔大学运筹学与工业工程,1988 年。
[8] 通过平均加速随机逼近。Boris T Polyak 和 Anatoli B Juditsky;SIAM 控制与优化杂志,30(4):838–855,1992 年。
[9] 损失曲面、模式连接和 DNN 的快速集成 Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson。神经信息处理系统 (NeurIPS),2018。
[10] 贝叶斯深度学习与泛化的概率视角 Andrew Gordon Wilson, Pavel Izmailov. ArXiv 预印本,2020。
[11] 并行随机权重平均:泛化能力好的大批量训练 Gupta, Vipul, Santiago Akle Serrano, and Dennis DeCoste;国际学习表示会议 (ICLR)。2019 年。
[12] SQWA:随机量化权重平均,用于提高低精度深度神经网络的泛化能力 Shin, Sungho, Yoonho Boo, and Wonyong Sung;arXiv 预印本 2020。