跳转到主要内容
博客

PyTorch 1.6 现已包含随机权重平均

作者: 2020 年 8 月 18 日2024 年 11 月 16 日暂无评论

您使用随机梯度下降 (SGD) 还是 Adam?无论您使用哪种程序来训练神经网络,通过 PyTorch 1.6 现在原生支持的一种简单新技术——随机权重平均 (SWA) [1],您很可能以几乎零附加成本实现显著更好的泛化能力。即使您已经训练了模型,也可以通过从预训练模型开始,运行 SWA 少量 epoch 来轻松实现 SWA 的优势。研究人员一次又一次地发现,SWA 以低成本或低工作量改善了在各种实际应用中经过良好调整的模型的性能!

SWA 具有广泛的应用范围和功能

  • 与计算机视觉中的标准训练技术相比,SWA 显著提高了性能(例如,ImageNet 和 CIFAR 基准测试上的 VGG、ResNets、Wide ResNets 和 DenseNets [1, 2])。
  • SWA 在半监督学习和域适应的关键基准上提供了最先进的性能 [2]
  • SWA 被证明可以改善语言建模(例如,WikiText-2 上的 AWD-LSTM [4])和深度强化学习中的策略梯度方法 [3] 的性能。
  • SWAG 是 SWA 的扩展,可以在贝叶斯深度学习中近似贝叶斯模型平均,并在各种设置中实现最先进的不确定性校准结果。此外,其最近的泛化 MultiSWAG 提供了显著的额外性能增益并减轻了双下降 [4, 10]。另一种方法 Subspace Inference 在 SWA 解决方案周围的参数空间的一个小子空间中近似贝叶斯后验 [5]
  • 用于低精度训练的 SWA,SWALP,即使所有数字(包括梯度累加器)都量化到 8 位,也能与全精度 SGD 训练的性能相匹配 [6]
  • 并行 SWA,SWAP,通过使用大批量大小被证明可以大大加快神经网络的训练速度,特别是通过在 27 秒内将神经网络训练到 CIFAR-10 上 94% 的准确率创造了记录 [11]

图 1。CIFAR-100 上带有 Preactivation ResNet-164 的 SWA 和 SGD 的说明 [1]左图:三个 FGE 样本及其相应的 SWA 解决方案(在权重空间中平均)的测试误差曲面。中图右图:显示 SGD(收敛时)和 SWA 提出的权重(从 125 个训练 epoch 后 SGD 的相同初始化开始)的测试误差和训练损失曲面。有关这些图的构建方式的详细信息,请参阅 [1]

简而言之,SWA 对 SGD(或任何随机优化器)遍历的权重进行等权平均,并采用修改后的学习率调度(参见图 1 的左面板)。SWA 解决方案最终位于损失的宽平坦区域的中心,而 SGD 倾向于收敛到低损失区域的边界,使其容易受到训练和测试误差曲面之间偏移的影响(参见图 1 的中面板和右面板)。我们强调 SWA 可以与任何优化器(例如 Adam)一起使用,并且不限于 SGD

以前,SWA 在 PyTorch contrib 中。在 PyTorch 1.6 中,我们在 torch.optim.swa_utils 中提供了 SWA 的新便捷实现。

这仅仅是平均 SGD 吗?

从高层次来看,平均 SGD 迭代可以追溯到几十年前的凸优化 [7, 8],在那里它有时被称为 Polyak-Ruppert 平均或平均 SGD。但细节很重要。平均 SGD 通常与衰减学习率和指数移动平均 (EMA) 结合使用,通常用于凸优化。在凸优化中,重点在于提高收敛速度。在深度学习中,这种形式的平均 SGD 平滑了 SGD 迭代的轨迹,但表现并没有太大不同。

相比之下,SWA 使用 SGD 迭代的等权平均,并采用修改后的循环或高常数学习率,并利用深度学习特有的训练目标平坦性 [8]提高泛化能力

随机权重平均如何工作?

使 SWA 起作用有两个重要因素。首先,SWA 使用修改后的学习率调度,使 SGD(或其他优化器,如 Adam)继续在最优值附近波动并探索不同的模型,而不是简单地收敛到单个解决方案。例如,我们可以在前 75% 的训练时间使用标准衰减学习率策略,然后将学习率设置为一个合理的常数高值,用于剩余的 25% 的时间(参见下面的图 2)。第二个因素是对 SGD 遍历的网络权重进行平均(通常是等权平均)。例如,我们可以在训练时间的最后 25% 的每个 epoch 结束时维护所获得的权重的运行平均值(参见图 2)。训练完成后,我们将网络权重设置为计算出的 SWA 平均值。

图 2。SWA 采用的学习率调度说明。前 75% 的训练使用标准衰减调度,剩余 25% 使用高常数值。SWA 平均值在训练的最后 25% 期间形成。

一个重要的细节是批量归一化。批量归一化层在训练期间计算激活的运行统计量。请注意,SWA 权重的平均值在训练期间从不用于进行预测。因此,批量归一化层在训练结束时没有计算激活统计量。我们可以通过使用 SWA 模型对训练数据进行一次前向传播来计算这些统计量。

虽然我们在上述描述中为简化起见侧重于 SGD,但 SWA 可以与任何优化器结合使用。您还可以使用循环学习率而不是高常数值(例如,参见 [2])。

如何在 PyTorch 中使用 SWA?

torch.optim.swa_utils 中,我们实现了所有 SWA 成分,以便方便地将 SWA 与任何模型一起使用。特别是,我们为 SWA 模型实现了 AveragedModel 类,SWALR 学习率调度器,以及 update_bn 实用函数,用于在训练结束时更新 SWA 批量归一化统计量。

在下面的示例中,swa_model 是累积权重平均值的 SWA 模型。我们总共训练模型 300 个 epoch,并在 epoch 160 切换到 SWA 学习率调度并开始收集参数的 SWA 平均值。

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
      if epoch > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)

接下来,我们将详细解释 torch.optim.swa_utils 的每个组件。

AveragedModel 类用于计算 SWA 模型的权重。您可以通过运行 swa_model = AveragedModel(model) 创建一个平均模型。然后,您可以通过 swa_model.update_parameters(model) 更新平均模型的参数。默认情况下,AveragedModel 计算您提供的参数的运行等权平均值,但您也可以使用 avg_fn 参数自定义平均函数。在以下示例中,ema_model 计算指数移动平均值。

ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
0.1 * averaged_model_parameter + 0.9 * model_parameter
ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

在实践中,我们发现图 2 中带有修改学习率调度的等权平均值提供了最佳性能。

SWALR 是一个学习率调度器,它将学习率退火到一个固定值,然后保持不变。例如,以下代码创建一个调度器,该调度器在线性地将学习率从其初始值在每个参数组的 5 个 epoch 内退火到 0.05

swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, 
anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

我们还实现了余弦退火到固定值 (anneal_strategy="cos")。在实践中,我们通常在 epoch swa_start(例如,在 75% 的训练 epoch 之后)切换到 SWALR,并同时开始计算权重的运行平均值

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
swa_start = 75
for epoch in range(100):
      # <train epoch>
      if i > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()

最后,update_bn 是一个实用函数,用于计算给定数据加载器 loader 上的 SWA 模型的批归一化统计量

torch.optim.swa_utils.update_bn(loader, swa_model) 

update_bnswa_model 应用于数据加载器中的每个元素,并计算模型中每个批归一化层的激活统计量。

计算 SWA 平均值并更新批归一化层后,您可以应用 swa_model 在测试数据上进行预测。

它为什么有效?

损失曲面存在大片平坦区域 [9]。在下面的图 3 中,我们展示了参数空间子空间中损失曲面的可视化,该子空间包含连接两个独立训练的 SGD 解决方案的路径,使得沿该路径的每个点的损失都相似地低。SGD 收敛到这些区域的边界附近,因为没有太多梯度信号可以移入内部,因为该区域中的所有点都具有相似的低损失值。通过增加学习率,SWA 在这个平坦区域周围旋转,然后通过平均迭代,向平坦区域的中心移动。

图 3:CIFAR-10 数据集上没有跳过连接的 ResNet-20 的模式连通性可视化。该可视化是与 Javier Ideami (https://losslandscape.com/) 合作创建的。有关更多详细信息,请参阅这篇 博客文章

我们期望位于损失平坦区域中心的解决方案比靠近边界的解决方案具有更好的泛化能力。实际上,训练和测试误差曲面在权重空间中并非完美对齐。位于平坦区域中心的解决方案不像靠近边界的解决方案那样容易受到训练和测试误差曲面之间偏移的影响。在下面的图 4 中,我们展示了连接 SWA 和 SGD 解决方案方向上的训练损失和测试误差曲面。正如您所看到的,虽然 SWA 解决方案的训练损失高于 SGD 解决方案,但它位于低损失区域的中心,并且具有明显更好的测试误差。

图 4。连接 SWA 解决方案(圆圈)和 SGD 解决方案(方块)的直线上的训练损失和测试误差。SWA 解决方案位于低训练损失的宽阔区域的中心,而 SGD 解决方案位于边界附近。由于训练损失和测试误差曲面之间的偏移,SWA 解决方案导致更好的泛化能力。

SWA 取得了哪些成果?

我们发布了一个 GitHub 仓库,其中包含使用 PyTorch 实现 SWA 训练 DNN 的示例。例如,这些示例可用于在 CIFAR-100 上实现以下结果

 VGG-16ResNet-164WideResNet-28×10
SGD72.8 ± 0.378.4 ± 0.381.0 ± 0.3
SWA74.4 ± 0.379.8 ± 0.482.5 ± 0.2

半监督学习

在后续的论文中,SWA 被应用于半监督学习,并在多种设置中改进了最佳报告结果[2]。例如,如果只有 4k 训练数据点的训练标签,使用 SWA 您可以在 CIFAR-10 上获得 95% 的准确率(此问题上之前的最佳报告结果是 93.7%)。这篇论文还探讨了在 epoch 内多次平均,这可以加速收敛并在给定时间内找到更平坦的解决方案。

图 5。快速 SWA 在 CIFAR-10 半监督学习上的性能。快速 SWA 在所考虑的每种设置中都取得了创纪录的结果。

强化学习

在另一篇后续论文中,SWA 被证明可以提高策略梯度方法 A2C 和 DDPG 在多个 Atari 游戏和 MuJoCo 环境中的性能[3]。此应用也是 SWA 与 Adam 结合使用的实例。请记住,SWA 不特定于 SGD,并且基本上可以使任何优化器受益。

环境名称A2CA2C + SWA
突破522 ± 34703 ± 60
Qbert18777 ± 77821272 ± 655
太空入侵者7727 ± 112121676 ± 8897
海中救援1779 ± 41795 ± 4
光束骑士9999 ± 40211321 ± 1065
疯狂攀爬者147030 ± 10239139752 ± 11618

低精度训练

我们可以通过结合向下舍入的权重和向上舍入的权重来过滤量化噪声。此外,通过平均权重以找到损失曲面的平坦区域,权重的大的扰动不会影响解决方案的质量(图 9 和图 10)。最近的工作表明,通过将 SWA 适应低精度设置,在一种称为 SWALP 的方法中,即使所有训练都以 8 位进行,也可以匹配全精度 SGD 的性能[5]。这是一个非常实际重要的结果,因为(1)8 位 SGD 训练的性能明显低于全精度 SGD,并且(2)低精度训练比训练后低精度预测(通常设置)困难得多。例如,在 CIFAR-100 上用浮点(16 位)SGD 训练的 ResNet-164 实现了 22.2% 的误差,而 8 位 SGD 实现了 24.0% 的误差。相比之下,采用 8 位训练的 SWALP 实现了 21.8% 的误差。

图 9。量化解决方案会导致权重扰动,这对手头解决方案(左)的质量影响大于对宽泛解决方案(右)的质量影响。

图 10。标准低精度训练与 SWALP 之间的差异。

另一项工作 SQWA 提出了一种对神经网络进行低精度量化和微调的方法[12]。特别是,SQWA 在 CIFAR-100 和 ImageNet 上对量化到 2 位的 DNN 取得了最先进的结果。

校准和不确定性估计

通过在损失中找到中心解,SWA 还可以改善校准和不确定性表示。实际上,SWA 可以被视为集成的一种近似,类似于贝叶斯模型平均,但只有一个模型 [1]

SWA 可以看作是采用修改后的学习率调度对 SGD 迭代进行一阶矩计算。我们可以通过同时计算迭代的二阶矩来形成权重的 Gaussian 近似后验,从而进一步表征 SGD 迭代的损失几何,直接推广 SWA。这种方法,SWA-高斯(SWAG)是贝叶斯深度学习中不确定性估计和校准的一种简单、可扩展和方便的方法 [4]。SWAG 分布近似真实后验的形状:下面的图 6 展示了 CIFAR-10 上 ResNet-20 的 SWAG 分布和后验对数密度。

图 6。SWAG 后验近似和在 SWAG 协方差矩阵的两个最大特征值形成的子空间中,用于在 CIFAR-10 上训练的无跳过连接 ResNet-20 的损失曲面。SWAG 分布的形状与后验对齐:两个分布的峰值重合,并且两个分布在一个方向上比在正交方向上更宽。可视化由 Javier Ideami 合作创建。

从经验来看,SWAG 在计算机视觉任务中的不确定性量化、分布外检测、校准和迁移学习方面,表现与包括 MC Dropout、KFAC Laplace 和温度缩放等流行替代方法相当或更好。SWAG 的代码可在此处获取:https://github.com/wjmaddox/swa_gaussian

图 7。MultiSWAG 推广了 SWAG 和深度集成,对多个吸引盆进行贝叶斯模型平均,从而显著提高了性能。相比之下,如图所示,深度集成选择不同的模式,而标准变分推断 (VI) 在单个吸引盆内进行边缘化(模型平均)。

MultiSWAG [9] 使用多个独立的 SWAG 模型形成高斯混合作为近似后验分布。不同的吸引盆包含高度互补的数据解释。因此,对这些多个吸引盆进行边缘化可以显著提高准确性和不确定性表示。MultiSWAG 可以看作是深度集成的一种泛化,但具有性能改进。

事实上,我们在图 8 中看到 MultiSWAG 完全减轻了双下降——更灵活的模型具有单调改进的性能——并且与 SGD 相比,提供了显著改进的泛化能力。例如,当 ResNet-18 具有 20 个宽度的层时,Multi-SWAG 实现了低于 30% 的错误率,而 SGD 实现了超过 45% 的错误率,差距超过 15%!

图 8。不同宽度的 ResNet-18 在 CIFAR-100 上的 SGD、SWAG 和 Multi-SWAG。我们看到 Multi-SWAG 特别减轻了双下降,并提供了比 SGD 显著的准确性改进。

参考文献 [10] 还考虑了 Multi-SWA,它在集成中使用了多个独立训练的 SWA 解决方案,在不增加任何额外计算成本的情况下提高了深度集成的性能。MultiSWA 和 MultiSWAG 的代码可在此处获取 这里

另一种方法 Subspace Inference 在 SWA 解决方案周围构建一个低维子空间,并在该子空间中对权重进行边缘化以近似贝叶斯模型平均 [5]。Subspace Inference 利用 SGD 迭代的统计数据来构建 SWA 解决方案和子空间。该方法在分类和回归问题中都取得了预测精度和不确定性校准方面的强大性能。代码可在 这里 获取。

试一试!

深度学习中最大的开放问题之一是,为什么 SGD 能够找到好的解决方案,鉴于训练目标是高度多峰的,并且存在许多参数设置,它们实现零训练损失但泛化能力差。通过理解与泛化相关的几何特征,例如平坦度,我们可以开始解决这些问题,并构建能够提供更好泛化能力和许多其他有用功能(例如不确定性表示)的优化器。我们介绍了 SWA,它是一个简单的替代标准优化器(如 SGD 和 Adam)的即插即用型组件,原则上可以使任何训练深度神经网络的人受益。SWA 已被证明在多个领域具有强大的性能,包括计算机视觉、半监督学习、强化学习、不确定性表示、校准、贝叶斯模型平均和低精度训练。

我们鼓励您尝试 SWA!SWA 现在与 PyTorch 中的任何标准训练一样简单。即使您已经训练了模型,您也可以通过从预训练模型开始运行少量 epoch 来显著提高性能。

[1] 平均权重带来更宽的局部最优和更好的泛化;Pavel Izmailov、Dmitry Podoprikhin、Timur Garipov、Dmitry Vetrov、Andrew Gordon Wilson;人工智能不确定性 (UAI),2018。

[2] 未标记数据有许多一致的解释:为什么要进行平均;Ben Athiwaratkun、Marc Finzi、Pavel Izmailov、Andrew Gordon Wilson;国际学习表示会议 (ICLR),2019。

[3] 通过权重平均提高深度强化学习的稳定性;Evgenii Nikishin、Pavel Izmailov、Ben Athiwaratkun、Dmitrii Podoprikhin、Timur Garipov、Pavel Shvechikov、Dmitry Vetrov、Andrew Gordon Wilson;UAI 2018 研讨会:深度学习中的不确定性,2018。

[4] 深度学习中贝叶斯不确定性的简单基线 Wesley Maddox, Timur Garipov, Pavel Izmailov, Andrew Gordon Wilson; 神经信息处理系统 (NeurIPS), 2019。

[5] 贝叶斯深度学习的子空间推断 Pavel Izmailov、Wesley Maddox、Polina Kirichenko、Timur Garipov、Dmitry Vetrov、Andrew Gordon Wilson 人工智能不确定性 (UAI),2019。

[6] SWALP:低精度训练中的随机权重平均 Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa; 国际机器学习会议 (ICML), 2019。

[7] 大卫·鲁珀特。从缓慢收敛的 Robbins-Monro 过程中进行有效估计;技术报告,康奈尔大学运筹学与工业工程,1988 年。

[8] 通过平均加速随机近似。Boris T Polyak 和 Anatoli B Juditsky;SIAM Journal on Control and Optimization,30(4):838–855,1992。

[9] DNN 的损失曲面、模式连接和快速集成 Timur Garipov、Pavel Izmailov、Dmitrii Podoprikhin、Dmitry Vetrov、Andrew Gordon Wilson。神经信息处理系统 (NeurIPS),2018。

[10] 贝叶斯深度学习与泛化的概率视角 Andrew Gordon Wilson, Pavel Izmailov。ArXiv 预印本,2020。

[11] 并行随机权重平均:泛化能力强的大批量训练 Gupta, Vipul, Santiago Akle Serrano, and Dennis DeCoste; 国际学习表示会议 (ICLR)。2019。

[12] SQWA:随机量化权重平均,用于提高低精度深度神经网络的泛化能力 Shin, Sungho, Yoonho Boo, and Wonyong Sung; arXiv 预印本 2020。