在这篇博文中,我们将介绍最近提出的随机权重平均(Stochastic Weight Averaging, SWA)技术 [1, 2],及其在 torchcontrib 中的新实现。SWA 是一种简单的程序,无需额外成本即可改善深度学习中随机梯度下降(SGD)的泛化能力,并且可以作为 PyTorch 中任何其他优化器的直接替代方案。SWA 具有广泛的应用和功能。
- 事实证明,SWA 显著提高了计算机视觉任务的泛化能力,包括 ImageNet 和 CIFAR 基准测试上的 VGG、ResNets、Wide ResNets 和 DenseNets [1, 2]。
- SWA 在半监督学习和域适应的关键基准上提供了最先进的性能 [2]。
- 研究表明,SWA 提高了深度强化学习中策略梯度方法训练的稳定性以及最终的平均奖励 [3]。
- SWA 的扩展版本可以实现高效的贝叶斯模型平均,以及深度学习中高质量的不确定性估计和校准 [4]。
- 用于低精度训练的 SWA(即 SWALP)即使在所有数值(包括梯度累加器)均量化为 8 位的情况下,也能媲美全精度 SGD 的性能 [5]。
简而言之,SWA 对通过修改学习率调度(见图 1 左侧面板)由 SGD 遍历的权重进行等权重平均。SWA 的解最终位于损失函数平坦区域的中心,而 SGD 往往收敛到低损区域的边界,使其容易受到训练误差曲面和测试误差曲面之间偏移的影响(见图 1 的中间和右侧面板)。

图 1. 在 CIFAR-100 上使用 Preactivation ResNet-164 对 SWA 和 SGD 的说明 [1]。左图:三个 FGE 样本的测试误差曲面及对应的 SWA 解(权重空间平均)。中图和右图:显示由 SGD(收敛时)和 SWA 提出的权重的测试误差和训练损失曲面,两者均从 125 个训练周期后的同一 SGD 初始化点开始。有关这些图表是如何构建的详细信息,请参阅 [1]。
通过我们在 torchcontrib 中的新实现,使用 SWA 和使用 PyTorch 中的任何其他优化器一样简单。
from torchcontrib.optim import SWA
...
...
# training loop
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):
opt.zero_grad()
loss_fn(model(input), target).backward()
opt.step()
opt.swap_swa_sgd()
您可以使用 SWA 类包装 torch.optim 中的任何优化器,然后像往常一样训练模型。训练完成后,只需调用 swap_swa_sgd() 即可将模型的权重设置为它们的 SWA 平均值。下面我们详细解释 SWA 程序和 SWA 类的参数。我们强调,SWA 可以与任何优化程序(例如 Adam)结合使用,就像它与 SGD 结合使用一样。
这仅仅是平均 SGD 吗?
从宏观层面来看,对 SGD 迭代进行平均可以追溯到几十年前的凸优化领域 [6, 7],在那里它有时被称为 Polyak-Ruppert 平均或平均 SGD。但细节决定成败。平均 SGD 通常与衰减的学习率和指数移动平均线结合使用,通常用于凸优化。在凸优化中,重点在于提高收敛速度。在深度学习中,这种形式的平均 SGD 可以平滑 SGD 迭代的轨迹,但效果并无太大差异。
相比之下,SWA 侧重于 SGD 迭代的等权重平均,并结合周期性或较高的恒定学习率,利用深度学习特有的训练目标平坦性 [8] 来改善泛化能力。
随机权重平均
使 SWA 有效的两个重要因素是:首先,SWA 使用了修改后的学习率调度,以便 SGD 继续探索高性能网络集合,而不是简单地收敛到一个单一的解。例如,我们可以对前 75% 的训练时间使用标准的衰减学习率策略,然后在剩下的 25% 的时间内将学习率设置为一个相当高的恒定值(见下图 2)。第二个因素是对 SGD 遍历的网络权重进行平均。例如,我们可以维护训练最后 25% 时间内每个周期结束时所获得的权重的运行平均值(见图 2)。

图 2. SWA 采用的学习率调度说明。前 75% 的训练使用标准衰减调度,剩余 25% 使用较高的恒定值。SWA 平均值是在训练的最后 25% 期间形成的。
在我们的实现中,SWA 优化器的自动模式允许我们运行上述程序。要以自动模式运行 SWA,只需使用 SWA(base_opt, swa_start, swa_freq, swa_lr) 包装您选择的优化器 base_opt(可以是 SGD、Adam 或任何其他 torch.optim.Optimizer)。在 swa_start 个优化步骤后,学习率将切换为一个恒定值 swa_lr,并且在每 swa_freq 个优化步骤结束时,权重的一个快照将被添加到 SWA 运行平均值中。一旦运行 opt.swap_swa_sgd(),模型的权重就会被它们的 SWA 运行平均值所取代。
批归一化(Batch Normalization)
需要牢记的一个重要细节是批归一化。批归一化层在训练期间计算激活的运行统计数据。请注意,SWA 的权重平均值在训练期间从不用于进行预测,因此在您使用 opt.swap_swa_sgd() 重置模型权重后,批归一化层并没有计算过激活统计数据。要计算激活统计数据,您只需在训练完成后使用 SWA 模型对训练数据进行一次前向传播即可。在 SWA 类中,我们提供了一个辅助函数 opt.bn_update(train_loader, model)。它通过在 train_loader 数据加载器上进行前向传播,更新模型中每个批归一化层的激活统计数据。您只需要在训练结束时调用此函数一次。
高级学习率调度
SWA 可以与任何鼓励探索平坦解区域的学习率调度结合使用。例如,您可以在训练时间的最后 25% 中使用周期性学习率,而不是恒定值,并在每个周期内对对应于最低学习率值的网络权重进行平均(见图 3)。

图 3. 使用替代学习率调度的 SWA 说明。最后 25% 的训练采用周期性学习率,并在每个周期结束时收集用于平均的模型。
在我们的实现中,您可以通过在手动模式下使用 SWA 来实现自定义的学习率和权重平均策略。以下代码等同于本博文开头介绍的自动模式代码。
opt = torchcontrib.optim.SWA(base_opt)
for i in range(100):
opt.zero_grad()
loss_fn(model(input), target).backward()
opt.step()
if i > 10 and i % 5 == 0:
opt.update_swa()
opt.swap_swa_sgd()
在手动模式下,您无需指定 swa_start、swa_lr 和 swa_freq,只需在您想要更新 SWA 运行平均值时调用 opt.update_swa()(例如在每个学习率周期结束时)。在手动模式下,SWA 不会更改学习率,因此您可以像使用任何其他 torch.optim.Optimizer 一样使用任何您想要的调度。
为什么它有效?
SGD 收敛到损失函数的宽平区域内的某个解。权重空间是极高维的,平坦区域的大部分体积集中在边界附近,因此 SGD 的解总是位于损失平坦区域的边界附近。另一方面,SWA 对多个 SGD 解进行平均,这使其能够向平坦区域的中心移动。
我们期望位于损失平坦区域中心的解比那些靠近边界的解具有更好的泛化能力。实际上,训练误差曲面和测试误差曲面在权重空间中并不完全对齐。位于平坦区域中心的解不像靠近边界的解那样容易受到训练和测试误差曲面之间偏移的影响。在下方的图 4 中,我们展示了沿连接 SWA 和 SGD 解方向的训练损失和测试误差曲面。正如您所见,虽然 SWA 解与 SGD 解相比具有更高的训练损失,但它位于低损区域的中心,并且具有明显更好的测试误差。

图 4. 沿连接 SWA 解(圆圈)和 SGD 解(方块)直线的训练损失和测试误差。SWA 解位于宽阔的低训练损失区域中心,而 SGD 解位于边界附近。由于训练损失和测试误差曲面之间的偏移,SWA 解带来了更好的泛化能力。
示例和结果
我们发布了一个 GitHub 仓库 此处,其中包含使用 torchcontrib 实现的 SWA 来训练深度神经网络(DNN)的示例。例如,这些示例可用于在 CIFAR-100 上取得以下结果:
| DNN(预算) | SGD | SWA 1 预算 | SWA 1.25 预算 | SWA 1.5 预算 |
|---|---|---|---|---|
| VGG16 (200) | 72.55 ± 0.10 | 73.91 ± 0.12 | 74.17 ± 0.15 | 74.27 ± 0.25 |
| PreResNet110 (150) | 76.77 ± 0.38 | 78.75 ± 0.16 | 78.91 ± 0.29 | 79.10 ± 0.21 |
| PreResNet164 (150) | 78.49 ± 0.36 | 79.77 ± 0.17 | 80.18 ± 0.23 | 80.35 ± 0.16 |
| WideResNet28x10 (200) | 80.82 ± 0.23 | 81.46 ± 0.23 | 81.91 ± 0.27 | 82.15 ± 0.27 |
半监督学习
在后续的一篇 论文 中,SWA 被应用于半监督学习,在多种设置下展现出超越此前最佳报告结果的改进。例如,如果您只有 4k 个训练数据点的标签,使用 SWA 可以在 CIFAR-10 上获得 95% 的准确率(该问题此前报告的最佳结果为 93.7%)。该论文还探索了在周期内多次平均,这可以加速收敛并在给定时间内找到更平坦的解。

图 5. fast-SWA 在 CIFAR-10 半监督学习中的表现。fast-SWA 在所考虑的每种设置下都取得了创纪录的成绩。
校准和不确定性估计
SWA-Gaussian (SWAG) 是一种用于贝叶斯深度学习中不确定性估计和校准的简单、可扩展且方便的方法。与维护 SGD 迭代运行平均值的 SWA 类似,SWAG 估计迭代的一阶和二阶矩,以构建权重上的高斯分布。SWAG 分布近似了真实后验分布的形状:下方的图 6 展示了 PreResNet-164 在 CIFAR-100 上的后验对数密度上的 SWAG 分布。

图 6. PreResNet-164 在 CIFAR-100 上后验对数密度上的 SWAG 分布。SWAG 分布的形状与后验分布对齐。
从经验来看,SWAG 在计算机视觉任务中的不确定性量化、分布外检测、校准和迁移学习方面,表现与包括 MC Dropout、KFAC Laplace 和温度缩放等流行替代方法相当或更好。SWAG 的代码可在此处获取:https://github.com/wjmaddox/swa_gaussian。
强化学习
在另一篇后续 论文 中,SWA 被证明可以提高几种 Atari 游戏和 MuJoCo 环境中策略梯度方法 A2C 和 DDPG 的性能。
| 环境 | A2C | A2C + SWA |
|---|---|---|
| 突破 | 522 ± 34 | 703 ± 60 |
| Qbert | 18777 ± 778 | 21272 ± 655 |
| 太空入侵者 | 7727 ± 1121 | 21676 ± 8897 |
| 海中救援 | 1779 ± 4 | 1795 ± 4 |
| 疯狂攀爬者 | 147030 ± 10239 | 139752 ± 11618 |
| 光束骑士 | 9999 ± 402 | 11321 ± 1065 |
低精度训练
我们可以通过将向下舍入的权重与向上舍入的权重相结合来过滤量化噪声。此外,通过平均权重以找到损失曲面的平坦区域,权重的大扰动不会影响解的质量(图 7 和 8)。最近的研究表明,通过将 SWA 调整为低精度设置,在一种称为 SWALP 的方法中,即使所有训练均在 8 位下进行,也可以媲美全精度 SGD 的性能 [5]。这是一个在实践中非常重要的结果,因为(1)8 位 SGD 训练的性能明显低于全精度 SGD,(2)低精度训练比训练后的低精度预测(常见设置)要困难得多。例如,在 CIFAR-100 上用浮点(16 位)SGD 训练的 ResNet-164 实现了 22.2% 的误差,而 8 位 SGD 实现了 24.0% 的误差。相比之下,8 位训练的 SWALP 实现了 21.8% 的误差。

图 7. 在平坦区域进行量化仍然可以提供低损失的解。

图 8. 低精度 SGD 训练(使用修改后的学习率调度)和 SWALP。
结论
深度学习中最大的未决问题之一是,鉴于训练目标是高度多模态的,并且原则上有许多参数设置可以达到零训练损失但泛化能力较差,SGD 究竟为何能找到好的解。通过理解与泛化相关的平坦性等几何特征,我们可以开始解决这些问题,并构建提供更好泛化能力的优化器,以及许多其他有用的功能,如不确定性表示。我们介绍了 SWA,这是一种标准 SGD 的简单直接替代品,原则上可以使任何训练深度神经网络的人受益。SWA 已被证明在多个领域具有强大的性能,包括计算机视觉、半监督学习、强化学习、不确定性表示、校准、贝叶斯模型平均和低精度训练。
我们鼓励您尝试使用 SWA!现在,使用 SWA 和使用 PyTorch 中的任何其他优化器一样简单。即使您已经使用 SGD(或其他任何优化器)训练了模型,通过使用预训练模型运行少量的 SWA 周期,也很容易实现 SWA 的优势。
[1] Averaging Weights Leads to Wider Optima and Better Generalization; Pavel Izmailov, Dmitry Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson; Uncertainty in Artificial Intelligence (UAI), 2018
[2] There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average; Ben Athiwaratkun, Marc Finzi, Pavel Izmailov, Andrew Gordon Wilson; International Conference on Learning Representations (ICLR), 2019
[3] Improving Stability in Deep Reinforcement Learning with Weight Averaging; Evgenii Nikishin, Pavel Izmailov, Ben Athiwaratkun, Dmitrii Podoprikhin, Timur Garipov, Pavel Shvechikov, Dmitry Vetrov, Andrew Gordon Wilson, UAI 2018 Workshop: Uncertainty in Deep Learning, 2018
[4] A Simple Baseline for Bayesian Uncertainty in Deep Learning, Wesley Maddox, Timur Garipov, Pavel Izmailov, Andrew Gordon Wilson, arXiv pre-print, 2019: https://arxiv.org/abs/1902.02476
[5] SWALP : Stochastic Weight Averaging in Low Precision Training, Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa, To appear at the International Conference on Machine Learning (ICML), 2019.
[6] David Ruppert. Efficient estimations from a slowly convergent Robbins-Monro process. Technical report, Cornell University Operations Research and Industrial Engineering, 1988.
[7] Acceleration of stochastic approximation by averaging. Boris T Polyak and Anatoli B Juditsky. SIAM Journal on Control and Optimization, 30(4):838–855, 1992.
[8] Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs, Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson. Neural Information Processing Systems (NeurIPS), 2018