在这篇博文中,我们将介绍最近提出的随机权重平均 (SWA) 技术 [1, 2],以及其在 torchcontrib
中的新实现。SWA 是一种简单的程序,可以在深度学习中改进随机梯度下降 (SGD) 的泛化能力,且无需额外成本,并且可以作为 PyTorch 中任何其他优化器的直接替代品。SWA 具有广泛的应用和功能
- SWA 已被证明可以显着提高计算机视觉任务中的泛化能力,包括 ImageNet 和 CIFAR 基准测试上的 VGG、ResNets、Wide ResNets 和 DenseNets [1, 2]。
- SWA 在半监督学习和领域自适应的关键基准测试中提供了最先进的性能 [2]。
- SWA 已被证明可以提高深度强化学习中策略梯度方法的训练稳定性和最终平均奖励 [3]。
- SWA 的扩展可以获得高效的贝叶斯模型平均,以及深度学习中的高质量不确定性估计和校准 [4]。
- 用于低精度训练的 SWA (SWALP) 即使在所有数字(包括梯度累加器)量化到 8 位的情况下,也可以匹敌全精度 SGD 的性能 [5]。
简而言之,SWA 对 SGD 遍历的权重进行等权重平均,并采用修改后的学习率计划(参见图 1 的左侧面板)。SWA 解决方案最终位于低损耗区域的中心宽阔平坦区域,而 SGD 往往会收敛到低损耗区域的边界,使其容易受到训练和测试误差面之间偏移的影响(参见图 1 的中间和右侧面板)。

图 1. CIFAR-100 上 Preactivation ResNet-164 的 SWA 和 SGD 说明 [1]。左图:三个 FGE 样本的测试误差面和相应的 SWA 解决方案(权重空间中的平均)。中图和右图:测试误差和训练损失面,显示 SGD(在收敛时)和 SWA 提出的权重,从 125 个训练周期后 SGD 的相同初始化开始。有关如何构建这些图的详细信息,请参见 [1]。
借助我们在 torchcontrib 中的新实现,使用 SWA 就像使用 PyTorch 中的任何其他优化器一样简单
from torchcontrib.optim import SWA
...
...
# training loop
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):
opt.zero_grad()
loss_fn(model(input), target).backward()
opt.step()
opt.swap_swa_sgd()
您可以使用 SWA
类包装来自 torch.optim
的任何优化器,然后像往常一样训练您的模型。训练完成后,您只需调用 swap_swa_sgd()
即可将模型的权重设置为其 SWA 平均值。下面我们将详细解释 SWA 程序和 SWA
类的参数。我们强调,SWA 可以与任何优化程序(例如 Adam)结合使用,就像它可以与 SGD 结合使用一样。
这只是平均 SGD 吗?
在高层次上,平均 SGD 迭代可以追溯到凸优化中的几十年 [6, 7],在凸优化中,它有时被称为 Polyak-Ruppert 平均或平均 SGD。但细节很重要。平均 SGD 通常与衰减的学习率和指数移动平均值结合使用,通常用于凸优化。在凸优化中,重点一直放在提高收敛速度上。在深度学习中,这种形式的平均 SGD 平滑了 SGD 迭代的轨迹,但性能差异不大。
相比之下,SWA 侧重于 SGD 迭代的等权重平均,并采用修改后的循环或高恒定学习率,并利用深度学习特有的训练目标的平坦性 [8] 来改进泛化。
随机权重平均
SWA 工作原理有两个重要因素。首先,SWA 使用修改后的学习率计划,以便 SGD 继续探索高性能网络集,而不是简单地收敛到单个解决方案。例如,我们可以在前 75% 的训练时间内使用标准衰减学习率策略,然后在剩余 25% 的时间内将学习率设置为合理的高恒定值(参见下图 2)。第二个因素是平均 SGD 遍历的网络的权重。例如,我们可以维护在最后 25% 的训练时间内每个周期结束时获得的权重的运行平均值(参见图 2)。

图 2. SWA 采用的学习率计划说明。标准衰减计划用于前 75% 的训练,然后高恒定值用于剩余 25%。SWA 平均值在最后 25% 的训练期间形成。
在我们的实现中,SWA
优化器的自动模式允许我们运行上述程序。要在自动模式下运行 SWA,您只需使用 SWA(base_opt, swa_start, swa_freq, swa_lr)
包装您选择的优化器 base_opt
(可以是 SGD、Adam 或任何其他 torch.optim.Optimizer
)。在 swa_start
优化步骤之后,学习率将切换为恒定值 swa_lr
,并且在每个 swa_freq
优化步骤结束时,权重的快照将添加到 SWA 运行平均值中。运行 opt.swap_swa_sgd()
后,模型的权重将替换为其 SWA 运行平均值。
批量归一化
需要记住的一个重要细节是批量归一化。批量归一化层在训练期间计算激活的运行统计数据。请注意,权重的 SWA 平均值永远不会用于在训练期间进行预测,因此在您使用 opt.swap_swa_sgd()
重置模型权重后,批量归一化层没有计算激活统计数据。要计算激活统计数据,您只需在训练完成后使用 SWA 模型对您的训练数据进行一次前向传递。在 SWA
类中,我们提供了一个辅助函数 opt.bn_update(train_loader, model)
。它通过对 train_loader
数据加载器进行前向传递来更新模型中每个批量归一化层的激活统计数据。您只需要在训练结束时调用此函数一次。
高级学习率计划
SWA 可以与任何鼓励探索平坦解区域的学习率计划一起使用。例如,您可以在最后 25% 的训练时间内使用循环学习率而不是恒定值,并平均每个周期内学习率最低值对应的网络的权重(参见图 3)。

图 3. 具有替代学习率计划的 SWA 说明。循环学习率在最后 25% 的训练中采用,并且用于平均的模型在每个周期结束时收集。
在我们的实现中,您可以使用手动模式下的 SWA
来实现自定义学习率和权重平均策略。以下代码等效于本博文开头介绍的自动模式代码。
opt = torchcontrib.optim.SWA(base_opt)
for i in range(100):
opt.zero_grad()
loss_fn(model(input), target).backward()
opt.step()
if i > 10 and i % 5 == 0:
opt.update_swa()
opt.swap_swa_sgd()
在手动模式下,您无需指定 swa_start
、swa_lr
和 swa_freq
,只需在您想要更新 SWA 运行平均值时调用 opt.update_swa()
(例如在每个学习率周期结束时)。在手动模式下,SWA
不会更改学习率,因此您可以像通常使用任何其他 torch.optim.Optimizer
一样使用您想要的任何计划。
为什么它有效?
SGD 收敛到损失平坦区域内的解决方案。权重空间是极高维的,并且平坦区域的大部分体积都集中在边界附近,因此 SGD 解决方案将始终在损失平坦区域的边界附近找到。另一方面,SWA 平均多个 SGD 解决方案,这使其能够朝着平坦区域的中心移动。
我们期望位于损失平坦区域中心的解决方案比边界附近的解决方案具有更好的泛化能力。实际上,训练和测试误差面在权重空间中并非完美对齐。位于平坦区域中心的解决方案不如边界附近的解决方案容易受到训练和测试误差面之间偏移的影响。在下图 4 中,我们显示了沿连接 SWA 和 SGD 解决方案方向的训练损失和测试误差面。如您所见,虽然 SWA 解决方案的训练损失高于 SGD 解决方案,但它位于低损失区域的中心,并且具有明显更好的测试误差。

图 4. 沿连接 SWA 解决方案(圆圈)和 SGD 解决方案(正方形)的线的训练损失和测试误差。SWA 解决方案位于训练损失低的宽阔区域的中心,而 SGD 解决方案位于边界附近。由于训练损失和测试误差面之间的偏移,SWA 解决方案导致更好的泛化。
示例和结果
我们发布了一个 GitHub 代码库 此处,其中包含使用 torchcontrib
实现的 SWA 来训练 DNN 的示例。例如,这些示例可用于在 CIFAR-100 上实现以下结果
DNN(预算) | SGD | SWA 1 个预算 | SWA 1.25 个预算 | SWA 1.5 个预算 |
---|---|---|---|---|
VGG16 (200) | 72.55 ± 0.10 | 73.91 ± 0.12 | 74.17 ± 0.15 | 74.27 ± 0.25 |
PreResNet110 (150) | 76.77 ± 0.38 | 78.75 ± 0.16 | 78.91 ± 0.29 | 79.10 ± 0.21 |
PreResNet164 (150) | 78.49 ± 0.36 | 79.77 ± 0.17 | 80.18 ± 0.23 | 80.35 ± 0.16 |
WideResNet28x10 (200) | 80.82 ± 0.23 | 81.46 ± 0.23 | 81.91 ± 0.27 | 82.15 ± 0.27 |
半监督学习
在一篇后续 论文 中,SWA 被应用于半监督学习,其中它展示了超出多个设置中最佳报告结果的改进。例如,使用 SWA,如果您只有 4k 训练数据点的训练标签,您可以在 CIFAR-10 上获得 95% 的准确率(该问题之前报告的最佳结果为 93.7%)。本文还探讨了在周期内多次平均,这可以加速收敛并在给定的时间内找到更平坦的解决方案。

图 5. fast-SWA 在 CIFAR-10 半监督学习中的性能。fast-SWA 在每个考虑的设置中都取得了创纪录的成绩。
校准和不确定性估计
SWA-Gaussian (SWAG) 是一种简单、可扩展且方便的方法,用于贝叶斯深度学习中的不确定性估计和校准。与 SWA 类似,SWA 保持 SGD 迭代的运行平均值,SWAG 估计迭代的一阶矩和二阶矩,以构建权重上的高斯分布。SWAG 分布近似于真实后验的形状:下图 6 显示了 CIFAR-100 上 PreResNet-164 的后验对数密度之上的 SWAG 分布。

图 6. CIFAR-100 上 PreResNet-164 的后验对数密度之上的 SWAG 分布。SWAG 分布的形状与后验对齐。
经验表明,在计算机视觉任务中的不确定性量化、分布外检测、校准和迁移学习方面,SWAG 的性能与流行的替代方案(包括 MC dropout、KFAC Laplace 和温度缩放)相当或更好。SWAG 的代码可在 此处 获取。
强化学习
在另一篇后续 论文 中,SWA 被证明可以提高策略梯度方法 A2C 和 DDPG 在多个 Atari 游戏和 MuJoCo 环境中的性能。
环境 | A2C | A2C + SWA |
---|---|---|
Breakout | 522 ± 34 | 703 ± 60 |
Qbert | 18777 ± 778 | 21272 ± 655 |
SpaceInvaders | 7727 ± 1121 | 21676 ± 8897 |
Seaquest | 1779 ± 4 | 1795 ± 4 |
CrazyClimber | 147030 ± 10239 | 139752 ± 11618 |
BeamRider | 9999 ± 402 | 11321 ± 1065 |
低精度训练
我们可以通过将向下舍入的权重与向上舍入的权重相结合来过滤量化噪声。此外,通过平均权重以找到损失表面的平坦区域,权重的较大扰动不会影响解决方案的质量(图 7 和 8)。最近的工作表明,通过将 SWA 适应于低精度设置(在一种称为 SWALP 的方法中),即使在所有训练都在 8 位的情况下,也可以匹敌全精度 SGD 的性能 [5]。这是一个非常实际重要的结果,因为 (1) 8 位 SGD 训练的性能明显比全精度 SGD 差,并且 (2) 低精度训练比训练后低精度预测(常用设置)困难得多。例如,在 CIFAR-100 上使用浮点 (16 位) SGD 训练的 ResNet-164 实现了 22.2% 的误差,而 8 位 SGD 实现了 24.0% 的误差。相比之下,使用 8 位训练的 SWALP 实现了 21.8% 的误差。

图 7. 在平坦区域中量化仍然可以提供低损耗的解决方案。

图 8. 低精度 SGD 训练(使用修改后的学习率计划)和 SWALP。
结论
深度学习中最伟大的未解决问题之一是,鉴于训练目标是高度多模态的,并且原则上存在许多实现无训练损失但泛化能力差的参数设置,为什么 SGD 设法找到好的解决方案。通过理解与泛化相关的平坦度等几何特征,我们可以开始解决这些问题,并构建提供更好泛化能力以及许多其他有用功能(例如不确定性表示)的优化器。我们介绍了 SWA,这是一种标准 SGD 的简单直接替代品,原则上可以使任何训练深度神经网络的人受益。SWA 已被证明在许多领域都具有强大的性能,包括计算机视觉、半监督学习、强化学习、不确定性表示、校准、贝叶斯模型平均和低精度训练。
我们鼓励您试用 SWA!现在使用 SWA 就像使用 PyTorch 中的任何其他优化器一样简单。即使您已经使用 SGD(或任何其他优化器)训练了模型,也可以通过使用预训练模型启动 SWA 少量周期来轻松实现 SWA 的优势。
- [1] Averaging Weights Leads to Wider Optima and Better Generalization; Pavel Izmailov, Dmitry Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson; Uncertainty in Artificial Intelligence (UAI), 2018
- [2] There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average; Ben Athiwaratkun, Marc Finzi, Pavel Izmailov, Andrew Gordon Wilson; International Conference on Learning Representations (ICLR), 2019
- [3] Improving Stability in Deep Reinforcement Learning with Weight Averaging; Evgenii Nikishin, Pavel Izmailov, Ben Athiwaratkun, Dmitrii Podoprikhin, Timur Garipov, Pavel Shvechikov, Dmitry Vetrov, Andrew Gordon Wilson, UAI 2018 Workshop: Uncertainty in Deep Learning, 2018
- [4] A Simple Baseline for Bayesian Uncertainty in Deep Learning, Wesley Maddox, Timur Garipov, Pavel Izmailov, Andrew Gordon Wilson, arXiv pre-print, 2019: https://arxiv.org/abs/1902.02476
- [5] SWALP : Stochastic Weight Averaging in Low Precision Training, Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa, To appear at the International Conference on Machine Learning (ICML), 2019.
- [6] David Ruppert. Efficient estimations from a slowly convergent Robbins-Monro process. Technical report, Cornell University Operations Research and Industrial Engineering, 1988.
- [7] Acceleration of stochastic approximation by averaging. Boris T Polyak and Anatoli B Juditsky. SIAM Journal on Control and Optimization, 30(4):838–855, 1992.
- [8] Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs, Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson. Neural Information Processing Systems (NeurIPS), 2018