跳转到主要内容
博客

PyTorch 中的随机权重平均

作者: 2019 年 4 月 29 日2024 年 11 月 16 日暂无评论

在这篇博文中,我们介绍了最近提出的随机权重平均(SWA)技术 [1, 2],以及它在 torchcontrib 中的新实现。SWA 是一种简单的过程,可以在不增加额外成本的情况下改善深度学习中相对于随机梯度下降(SGD)的泛化能力,并且可以作为 PyTorch 中任何其他优化器的直接替代品。SWA 具有广泛的应用和特性:

  1. SWA 已被证明能显著改善计算机视觉任务的泛化能力,包括 ImageNet 和 CIFAR 基准上的 VGG、ResNets、Wide ResNets 和 DenseNets [1, 2]
  2. SWA 在半监督学习和域适应的关键基准上提供了最先进的性能 [2]
  3. SWA 被证明可以提高深度强化学习中策略梯度方法的训练稳定性以及最终的平均奖励 [3]
  4. SWA 的扩展可以实现高效的贝叶斯模型平均,以及深度学习中高质量的不确定性估计和校准 [4]
  5. 用于低精度训练的 SWA(SWALP)即使所有数字(包括梯度累加器)都量化到 8 位,也能与全精度 SGD 的性能相匹配 [5]

简而言之,SWA 对 SGD 遍历的权重进行等权平均,并采用修改后的学习率调度(见图 1 的左图)。SWA 解决方案最终位于损失的宽平坦区域的中心,而 SGD 倾向于收敛到低损失区域的边界,这使得它容易受到训练误差和测试误差曲面之间偏移的影响(见图 1 的中图和右图)。

图 1. CIFAR-100 上使用 Preactivation ResNet-164 的 SWA 和 SGD 示例 [1]左:三个 FGE 样本和相应的 SWA 解决方案(权重空间中的平均值)的测试误差曲面。右:测试误差和训练损失曲面,显示了从相同的 SGD 初始值开始,经过 125 个训练 epoch 后,SGD(收敛时)和 SWA 建议的权重。有关这些图的构建细节,请参见 [1]

通过我们在 torchcontrib 中的新实现,使用 SWA 就像使用 PyTorch 中的任何其他优化器一样简单。

from torchcontrib.optim import SWA

...
...

# training loop
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):
     opt.zero_grad()
     loss_fn(model(input), target).backward()
     opt.step()
opt.swap_swa_sgd()

您可以使用 SWA 类封装 torch.optim 中的任何优化器,然后像往常一样训练您的模型。训练完成后,您只需调用 swap_swa_sgd() 即可将模型的权重设置为其 SWA 平均值。下面我们将详细解释 SWA 过程和 SWA 类的参数。我们强调,SWA 可以与任何优化过程(例如 Adam)结合使用,就像它与 SGD 结合使用一样。

这仅仅是平均 SGD 吗?

从高层次来看,对 SGD 迭代进行平均可以追溯到凸优化领域的几十年前 [6, 7],有时被称为 Polyak-Ruppert 平均,或平均 SGD。但细节很重要平均 SGD 通常与衰减学习率和指数移动平均结合使用,通常用于凸优化。在凸优化中,重点在于提高收敛速度。在深度学习中,这种形式的平均 SGD 平滑了 SGD 迭代的轨迹,但表现并没有太大不同。

相比之下,SWA 侧重于对 SGD 迭代进行等权平均,采用修改后的循环或高恒定学习率,并利用深度学习特有的训练目标的平坦性 [8],以实现改进的泛化能力

随机权重平均

SWA 的成功有两个重要因素。首先,SWA 使用修改后的学习率调度,以便 SGD 继续探索高性能网络集,而不是简单地收敛到单个解决方案。例如,我们可以对前 75% 的训练时间使用标准的衰减学习率策略,然后将学习率设置为合理的恒定高值,用于剩余的 25% 时间(见下图 2)。第二个因素是平均 SGD 遍历的网络权重。例如,我们可以在训练时间后 25% 的每个 epoch 结束时保持权重的运行平均值(见图 2)。

图 2. SWA 采用的学习率调度示例。前 75% 的训练使用标准衰减调度,然后剩余 25% 使用高恒定值。SWA 平均值在训练的最后 25% 形成。

在我们的实现中,SWA 优化器的自动模式允许我们运行上述过程。要在自动模式下运行 SWA,您只需将您选择的优化器 base_opt(可以是 SGD、Adam 或任何其他 torch.optim.Optimizer)与 SWA(base_opt, swa_start, swa_freq, swa_lr) 包装起来。在 swa_start 优化步骤之后,学习率将切换到一个常量值 swa_lr,并且在每 swa_freq 优化步骤结束时,权重的快照将被添加到 SWA 运行平均值中。一旦您运行 opt.swap_swa_sgd(),您的模型的权重将被其 SWA 运行平均值替换。

批归一化

一个需要记住的重要细节是批归一化。批归一化层在训练期间计算激活的运行统计数据。请注意,SWA 的权重平均值在训练期间从不用于进行预测,因此在您使用 opt.swap_swa_sgd() 重置模型权重后,批归一化层不会计算激活统计数据。要计算激活统计数据,您只需在训练完成后使用 SWA 模型对您的训练数据进行一次前向传播即可。在 SWA 类中,我们提供了一个辅助函数 opt.bn_update(train_loader, model)。它通过对 train_loader 数据加载器的数据进行前向传播,更新模型中每个批归一化层的激活统计数据。您只需要在训练结束时调用此函数一次。

高级学习率调度

SWA 可以与任何鼓励探索解决方案平坦区域的学习率调度结合使用。例如,您可以在训练时间的最后 25% 使用循环学习率而不是恒定值,并平均每个周期内学习率最低值对应的网络权重(见图 3)。

图 3. 使用替代学习率调度的 SWA 示例。在训练的最后 25% 采用循环学习率,并在每个周期结束时收集用于平均的模型。

在我们的实现中,您可以通过在手动模式下使用 SWA 来实现自定义学习率和权重平均策略。以下代码等同于本博文开头介绍的自动模式代码。

opt = torchcontrib.optim.SWA(base_opt)
for i in range(100):
    opt.zero_grad()
    loss_fn(model(input), target).backward()
    opt.step()
    if i > 10 and i % 5 == 0:
        opt.update_swa()
opt.swap_swa_sgd()

在手动模式下,您无需指定 swa_startswa_lrswa_freq,只需在需要更新 SWA 运行平均值时(例如在每个学习率周期结束时)调用 opt.update_swa()。在手动模式下,SWA 不会改变学习率,因此您可以像使用任何其他 torch.optim.Optimizer 一样,使用任何您想要的调度。

它为什么有效?

SGD 收敛到损失的宽平坦区域内的解决方案。权重空间是极其高维的,平坦区域的大部分体积都集中在边界附近,因此 SGD 解决方案总是会在损失的平坦区域的边界附近找到。另一方面,SWA 平均了多个 SGD 解决方案,这使其能够向平坦区域的中心移动。

我们预计位于损失平坦区域中心的解决方案比边界附近的解决方案泛化能力更好。实际上,训练误差和测试误差曲面在权重空间中并非完全对齐。位于平坦区域中心的解决方案不像边界附近的解决方案那样容易受到训练误差和测试误差曲面之间偏移的影响。在下面的图 4 中,我们展示了连接 SWA 和 SGD 解决方案方向上的训练损失和测试误差曲面。正如您所看到的,虽然 SWA 解决方案的训练损失高于 SGD 解决方案,但它位于低损失区域的中心,并且具有明显更好的测试误差。

图 4. 沿连接 SWA 解决方案(圆圈)和 SGD 解决方案(方块)的线的训练损失和测试误差。SWA 解决方案位于低训练损失的宽区域的中心,而 SGD 解决方案位于边界附近。由于训练损失和测试误差曲面之间的偏移,SWA 解决方案带来了更好的泛化能力。

示例和结果

我们在此发布了一个 GitHub 仓库 here,其中包含使用 torchcontrib 实现 SWA 训练 DNN 的示例。例如,这些示例可用于在 CIFAR-100 上实现以下结果:

DNN (预算)SGDSWA 1 预算SWA 1.25 预算SWA 1.5 预算
VGG16 (200)72.55 ± 0.1073.91 ± 0.1274.17 ± 0.1574.27 ± 0.25
PreResNet110 (150)76.77 ± 0.3878.75 ± 0.1678.91 ± 0.2979.10 ± 0.21
PreResNet164 (150)78.49 ± 0.3679.77 ± 0.1780.18 ± 0.2380.35 ± 0.16
WideResNet28x10 (200)80.82 ± 0.2381.46 ± 0.2381.91 ± 0.2782.15 ± 0.27

半监督学习

在后续的 论文 中,SWA 被应用于半监督学习,它在多种设置下都展示了超越最佳报告结果的改进。例如,如果您只有 4k 训练数据点的训练标签,使用 SWA 可以在 CIFAR-10 上获得 95% 的准确率(此问题之前报告的最佳结果是 93.7%)。这篇论文还探讨了在 epoch 内多次平均,这可以加速收敛并在给定时间内找到更平坦的解决方案。

图 5. CIFAR-10 半监督学习中 fast-SWA 的性能。fast-SWA 在所考虑的每个设置中都取得了创纪录的结果。

校准和不确定性估计

SWA-高斯(SWAG)是一种简单、可扩展且方便的方法,用于贝叶斯深度学习中的不确定性估计和校准。与 SWA 类似,SWA 维护 SGD 迭代的运行平均值,而 SWAG 估计迭代的一阶和二阶矩以构建权重的高斯分布。SWAG 分布近似于真实后验的形状:下面的图 6 显示了 CIFAR-100 上 PreResNet-164 的 SWAG 分布在后验对数密度之上的情况。

图 6. CIFAR-100 上 PreResNet-164 的 SWAG 分布在后验对数密度之上。SWAG 分布的形状与后验对齐。

从经验上讲,SWAG 在计算机视觉任务中的不确定性量化、分布外检测、校准和迁移学习方面,与包括 MC dropout、KFAC Laplace 和温度缩放等流行替代方法相比,表现相当或更好。SWAG 的代码可在此处获取:https://github.com/wjmaddox/swa_gaussian

强化学习

在另一篇后续 论文 中,SWA 被证明可以提高策略梯度方法 A2C 和 DDPG 在多个 Atari 游戏和 MuJoCo 环境中的性能。

环境A2CA2C + SWA
突破522 ± 34703 ± 60
Qbert18777 ± 77821272 ± 655
太空入侵者7727 ± 112121676 ± 8897
海中救援1779 ± 41795 ± 4
疯狂攀爬者147030 ± 10239139752 ± 11618
光束骑士9999 ± 40211321 ± 1065

低精度训练

我们可以通过结合向下取整的权重和向上取整的权重来过滤量化噪声。此外,通过平均权重以找到损失曲面的平坦区域,权重的较大扰动不会影响解决方案的质量(图 7 和 8)。最近的研究表明,通过将 SWA 适应到低精度设置,一种称为 SWALP 的方法,即使所有训练都以 8 位进行,也能与全精度 SGD 的性能相匹配 [5]。这是一个非常实际且重要的结果,因为 (1) 8 位 SGD 训练的性能明显低于全精度 SGD,以及 (2) 低精度训练比训练后进行低精度预测(通常设置)困难得多。例如,在 CIFAR-100 上使用浮点(16 位)SGD 训练的 ResNet-164 达到 22.2% 的误差,而 8 位 SGD 达到 24.0% 的误差。相比之下,使用 8 位训练的 SWALP 达到 21.8% 的误差。

图 7. 在平坦区域进行量化仍然可以提供低损失的解决方案。

图 8. 低精度 SGD 训练(带有修改后的学习率调度)和 SWALP。

结论

深度学习中最大的开放问题之一是,为什么 SGD 能够找到好的解决方案,鉴于训练目标是高度多模态的,并且原则上存在许多参数设置可以实现零训练损失但泛化能力差。通过理解与泛化相关的几何特征,例如平坦性,我们可以开始解决这些问题并构建提供更好泛化能力和许多其他有用功能(例如不确定性表示)的优化器。我们介绍了 SWA,它是一个简单的标准 SGD 直接替代品,原则上可以使任何训练深度神经网络的人受益。SWA 已被证明在许多领域具有强大的性能,包括计算机视觉、半监督学习、强化学习、不确定性表示、校准、贝叶斯模型平均和低精度训练。

我们鼓励您尝试 SWA!现在使用 SWA 就像使用 PyTorch 中的任何其他优化器一样简单。即使您已经使用 SGD(或任何其他优化器)训练了模型,也可以通过从预训练模型开始运行少量 epoch 的 SWA 来轻松实现 SWA 的优势。

[1] 平均权重带来更宽的局部最优和更好的泛化;Pavel Izmailov, Dmitry Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson;人工智能不确定性 (UAI),2018

[2] 无标签数据有许多一致的解释:为什么你应该平均;Ben Athiwaratkun, Marc Finzi, Pavel Izmailov, Andrew Gordon Wilson;国际学习表示会议 (ICLR),2019

[3] 通过权重平均提高深度强化学习的稳定性;Evgenii Nikishin, Pavel Izmailov, Ben Athiwaratkun, Dmitrii Podoprikhin, Timur Garipov, Pavel Shvechikov, Dmitry Vetrov, Andrew Gordon Wilson, UAI 2018 研讨会:深度学习中的不确定性,2018

[4] 深度学习中贝叶斯不确定性的一个简单基线,Wesley Maddox, Timur Garipov, Pavel Izmailov, Andrew Gordon Wilson, arXiv 预印本,2019:https://arxiv.org/abs/1902.02476

[5] SWALP:低精度训练中的随机权重平均,Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa, 将发表于国际机器学习会议 (ICML),2019。

[6] David Ruppert. 缓慢收敛 Robbins-Monro 过程的有效估计. 技术报告, 康奈尔大学运筹学与工业工程系, 1988。

[7] 通过平均加速随机逼近。Boris T Polyak 和 Anatoli B Juditsky。控制与优化 SIAM 期刊,30(4):838–855,1992。

[8] 损失曲面、模式连接和 DNN 的快速集成,Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson。神经信息处理系统 (NeurIPS),2018