作者:Pavel Izmailov 和 Andrew Gordon Wilson

在这篇博文中,我们将介绍最近提出的随机权重平均法(Stochastic Weight Averaging,简称 SWA)技术 [1, 2],及其在 torchcontrib 中的新实现。SWA 是一种简单的程序,与随机梯度下降(SGD)相比,它无需额外成本即可提高深度学习的泛化能力,并且可以作为 PyTorch 中任何其他优化器的直接替代品使用。SWA 具有广泛的应用和特性:

  1. 研究表明,SWA 显著提高了计算机视觉任务的泛化能力,包括 ImageNet 和 CIFAR 基准上的 VGG、ResNets、Wide ResNets 和 DenseNets [1, 2]。
  2. SWA 在半监督学习和域适应的关键基准上提供了最先进的性能 [2]。
  3. 研究表明,SWA 提高了深度强化学习中策略梯度方法的训练稳定性以及最终的平均奖励 [3]。
  4. SWA 的一个扩展可以实现高效的贝叶斯模型平均,并在深度学习中获得高质量的不确定性估计和校准 [4]。
  5. 用于低精度训练的 SWA,即 SWALP,即使将所有数字(包括梯度累加器)量化到 8 位,也能达到全精度 SGD 的性能 [5]。

简而言之,SWA 对采用修正学习率计划的 SGD 遍历的权重执行等权平均(参见图 1 的左图)。SWA 解最终位于损失函数宽平区域的中心,而 SGD 倾向于收敛到低损失区域的边界,使其易受训练和测试误差曲面之间偏移的影响(参见图 1 的中图和右图)。

图 1. SWA 和 SGD 在 CIFAR-100 上使用 Preactivation ResNet-164 的图示 [1]。左图:三个 FGE 样本及其对应的 SWA 解(在权重空间中进行平均)的测试误差曲面。中图右图:测试误差和训练损失曲面,显示了从 SGD 训练 125 个 epoch 后的相同初始化开始,由 SGD(收敛时)和 SWA 提出的权重。有关这些图构建的详细信息,请参阅 [1]。

使用我们在 torchcontrib 中的新实现,使用 SWA 就像使用 PyTorch 中的任何其他优化器一样简单

from torchcontrib.optim import SWA

...
...

# training loop
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):
     opt.zero_grad()
     loss_fn(model(input), target).backward()
     opt.step()
opt.swap_swa_sgd()

您可以使用 SWA 类包装来自 torch.optim 的任何优化器,然后像往常一样训练模型。训练完成后,只需调用 swap_swa_sgd() 将模型的权重设置为其 SWA 平均值。下面我们将详细解释 SWA 程序和 SWA 类的参数。我们强调,SWA 可以与任何优化过程(例如 Adam)结合使用,就像它可以与 SGD 结合使用一样。

这仅仅是平均 SGD 吗?

从概念上看,对 SGD 迭代进行平均可以追溯到凸优化领域的几十年前 [6, 7],有时被称为 Polyak-Ruppert 平均或平均 SGD。但细节决定成败平均 SGD 通常与衰减学习率和指数移动平均结合使用,通常用于凸优化。在凸优化中,重点在于提高收敛速度。在深度学习中,这种形式的平均 SGD 平滑了 SGD 迭代的轨迹,但表现差异不大。

相比之下,SWA 专注于对采用修正的周期性或高常数学习率的 SGD 迭代进行等权平均,并利用深度学习特有的训练目标平坦性 [8] 来提高泛化能力

随机权重平均法

有两个重要的组成部分使得 SWA 起作用。首先,SWA 使用修正的学习率计划,以便 SGD 继续探索高性能网络集合,而不是简单地收敛到单个解。例如,我们可以在前 75% 的训练时间使用标准的衰减学习率策略,然后在剩余的 25% 时间内将学习率设置为一个合理的常数值(参见下面的图 2)。第二个组成部分是对 SGD 遍历的网络的权重进行平均。例如,我们可以在训练时间的最后 25% 的每个 epoch 结束时维护获得的权重的运行平均值(参见图 2)。

图 2. SWA 采用的学习率计划图示。前 75% 的训练时间使用标准衰减计划,然后剩余 25% 使用高常数值。SWA 平均值在训练的最后 25% 期间形成。

在我们的实现中,SWA 优化器的自动模式允许我们运行上述过程。要在自动模式下运行 SWA,您只需使用 SWA(base_opt, swa_start, swa_freq, swa_lr) 包装您选择的优化器 base_opt(可以是 SGD、Adam 或任何其他 torch.optim.Optimizer)。在 swa_start 次优化步骤后,学习率将切换到常数值 swa_lr,并且在每 swa_freq 次优化步骤结束时,权重的快照将被添加到 SWA 运行平均值中。一旦您运行 opt.swap_swa_sgd(),模型的权重将被其 SWA 运行平均值替换。

批标准化

一个需要记住的重要细节是批标准化(Batch Normalization)。批标准化层在训练期间计算激活的运行统计量。请注意,SWA 平均权重在训练期间从未用于进行预测,因此在您使用 opt.swap_swa_sgd() 重置模型权重后,批标准化层没有计算激活统计量。要计算激活统计量,只需在训练完成后使用 SWA 模型对训练数据进行一次前向传播。在 SWA 类中,我们提供了一个辅助函数 opt.bn_update(train_loader, model)。它通过对 train_loader 数据加载器中的数据进行前向传播,更新模型中每个批标准化层的激活统计量。您只需在训练结束时调用一次此函数。

高级学习率计划

SWA 可以与任何鼓励探索解的平坦区域的学习率计划结合使用。例如,您可以在训练时间的最后 25% 使用周期性学习率而不是常数值,并在每个周期内学习率最低值的网络权重进行平均(参见图 3)。

图 3. 使用替代学习率计划的 SWA 图示。训练的最后 25% 采用周期性学习率,并在每个周期结束时收集用于平均的模型。

在我们的实现中,您可以在手动模式下使用 SWA 实现自定义学习率和权重平均策略。以下代码等同于本博文开头介绍的自动模式代码。

opt = torchcontrib.optim.SWA(base_opt)
for i in range(100):
    opt.zero_grad()
    loss_fn(model(input), target).backward()
    opt.step()
    if i > 10 and i % 5 == 0:
        opt.update_swa()
opt.swap_swa_sgd()

在手动模式下,您无需指定 swa_startswa_lrswa_freq,只需在您想要更新 SWA 运行平均值时调用 opt.update_swa()(例如在每个学习率周期结束时)。在手动模式下,SWA 不会改变学习率,因此您可以像使用任何其他 torch.optim.Optimizer 那样使用您想要的任何计划。

它为什么有效?

SGD 收敛到损失函数的宽平区域内的一个解。权重空间是极其高维的,平坦区域的大部分体积集中在边界附近,因此 SGD 解总是出现在低损失区域的边界附近。另一方面,SWA 平均多个 SGD 解,这使得它能够移向平坦区域的中心。

我们预期位于损失函数平坦区域中心的解比靠近边界的解具有更好的泛化能力。实际上,训练和测试误差曲面在权重空间中并非完美对齐。位于平坦区域中心的解不如靠近边界的解容易受到训练和测试误差曲面之间偏移的影响。在下面的图 4 中,我们展示了沿连接 SWA 和 SGD 解方向的训练损失和测试误差曲面。正如您所见,虽然 SWA 解的训练损失高于 SGD 解,但它位于低损失区域的中心,并具有明显更好的测试误差。

图 4. 沿连接 SWA 解(圆圈)和 SGD 解(方块)的直线上的训练损失和测试误差。SWA 解位于宽阔的低训练损失区域中心,而 SGD 解靠近边界。由于训练损失和测试误差曲面之间的偏移,SWA 解带来了更好的泛化能力。

示例和结果

我们在此处发布了一个 GitHub 仓库 here,其中包含使用 torchcontrib SWA 实现训练 DNN 的示例。例如,这些示例可用于在 CIFAR-100 上获得以下结果

DNN(预算) SGD SWA 1 预算 SWA 1.25 预算 SWA 1.5 预算
VGG16 (200) 72.55 ± 0.10 73.91 ± 0.12 74.17 ± 0.15 74.27 ± 0.25
PreResNet110 (150) 76.77 ± 0.38 78.75 ± 0.16 78.91 ± 0.29 79.10 ± 0.21
PreResNet164 (150) 78.49 ± 0.36 79.77 ± 0.17 80.18 ± 0.23 80.35 ± 0.16
WideResNet28x10 (200) 80.82 ± 0.23 81.46 ± 0.23 81.91 ± 0.27 82.15 ± 0.27

半监督学习

在后续的 论文 中,SWA 被应用于半监督学习,并在多种设置中展示了超越现有最佳结果的改进。例如,如果您仅有 4k 训练数据点的标签,使用 SWA 可以在 CIFAR-10 上获得 95% 的准确率(先前在此问题上的最佳报告结果为 93.7%)。该论文还探讨了在 epoch 内多次平均,这可以加速收敛并在给定时间内找到更平坦的解。

图 5. fast-SWA 在 CIFAR-10 半监督学习上的性能。fast-SWA 在所有考虑的设置中都取得了创纪录的结果。

校准与不确定性估计

SWA-Gaussian (SWAG) 是一种简单、可扩展且方便的贝叶斯深度学习不确定性估计和校准方法。与 SWA 维护 SGD 迭代的运行平均值类似,SWAG 估计迭代的一阶和二阶矩,以构建权重上的高斯分布。SWAG 分布近似于真实后验的形状:下面的图 6 显示了 CIFAR-100 上 PreResNet-164 后验对数密度之上的 SWAG 分布。

图 6. CIFAR-100 上 PreResNet-164 后验对数密度之上的 SWAG 分布。SWAG 分布的形状与后验对齐。

从经验上看,在计算机视觉任务中的不确定性量化、分布外检测、校准和迁移学习方面,SWAG 的性能与 MC dropout、KFAC Laplace 和温度缩放等流行替代方法相当或更好。SWAG 的代码可在此处获取:here

强化学习

在另一篇后续 论文 中,研究表明 SWA 提高了策略梯度方法 A2C 和 DDPG 在多个 Atari 游戏和 MuJoCo 环境中的性能。

环境 A2C A2C + SWA
Breakout 522 ± 34 703 ± 60
Qbert 18777 ± 778 21272 ± 655
SpaceInvaders 7727 ± 1121 21676 ± 8897
Seaquest 1779 ± 4 1795 ± 4
CrazyClimber 147030 ± 10239 139752 ± 11618
BeamRider 9999 ± 402 11321 ± 1065

低精度训练

我们可以通过结合向下舍入的权重和向上舍入的权重来过滤量化噪声。此外,通过平均权重找到损失曲面的平坦区域,权重的较大扰动不会影响解的质量(图 7 和图 8)。最近的工作表明,通过将 SWA 调整到低精度设置,一种称为 SWALP 的方法,可以达到即使在 8 位训练下也与全精度 SGD 相同的性能 [5]。考虑到 (1) 8 位 SGD 训练的性能显著低于全精度 SGD,以及 (2) 低精度训练远比训练后的低精度预测(通常设置)困难得多,这是一个非常重要的实际结果。例如,在 CIFAR-100 上使用 float (16 位) SGD 训练的 ResNet-164 达到 22.2% 的误差,而 8 位 SGD 达到 24.0% 的误差。相比之下,使用 8 位训练的 SWALP 达到 21.8% 的误差。

图 7. 在平坦区域进行量化仍然可以提供低损失的解。

图 8. 低精度 SGD 训练(采用修正学习率计划)和 SWALP。

结论

深度学习中最大的未决问题之一是,为什么 SGD 能够找到好的解,考虑到训练目标是高度多模态的,原则上有很多参数设置可以实现零训练损失但泛化能力很差。通过理解诸如平坦度等与泛化相关的几何特征,我们可以开始解决这些问题并构建能够提供更好泛化能力的优化器,以及许多其他有用的功能,例如不确定性表示。我们介绍了 SWA,它是标准 SGD 的简单直接替代品,原则上可以使任何训练深度神经网络的人受益。SWA 已被证明在计算机视觉、半监督学习、强化学习、不确定性表示、校准、贝叶斯模型平均和低精度训练等多个领域具有强大的性能。

我们鼓励您尝试 SWA!现在使用 SWA 就像使用 PyTorch 中的任何其他优化器一样简单。即使您已经使用 SGD(或任何其他优化器)训练了模型,通过使用预训练模型开始,运行少量 epoch 的 SWA 也很容易实现 SWA 的好处。

  • [1] 平均权重带来更宽的局部最优和更好的泛化能力;Pavel Izmailov, Dmitry Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson;人工智能不确定性会议 (UAI), 2018
  • [2] 未标记数据有很多一致的解释:为什么您应该进行平均;Ben Athiwaratkun, Marc Finzi, Pavel Izmailov, Andrew Gordon Wilson;国际学习表示会议 (ICLR), 2019
  • [3] 通过权重平均提高深度强化学习的稳定性;Evgenii Nikishin, Pavel Izmailov, Ben Athiwaratkun, Dmitrii Podoprikhin, Timur Garipov, Pavel Shvechikov, Dmitry Vetrov, Andrew Gordon Wilson, UAI 2018 研讨会:深度学习中的不确定性, 2018
  • [4] 贝叶斯深度学习中不确定性的简单基线;Wesley Maddox, Timur Garipov, Pavel Izmailov, Andrew Gordon Wilson;arXiv 预印本, 2019:https://arxiv.org/abs/1902.02476
  • [5] SWALP:低精度训练中的随机权重平均法;Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa;即将发表于国际机器学习大会 (ICML), 2019。
  • [6] David Ruppert. 从缓慢收敛的 Robbins-Monro 过程进行高效估计。技术报告,康奈尔大学运筹学和工业工程系,1988 年。
  • [7] 通过平均加速随机逼近。Boris T Polyak 和 Anatoli B Juditsky. SIAM Journal on Control and Optimization, 30(4):838–855, 1992。
  • [8] 损失曲面、模式连接和 DNN 的快速集成;Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson. 神经信息处理系统会议 (NeurIPS), 2018