快捷方式

torch.optim

torch.optim 是一个实现各种优化算法的包。

最常用的方法已经得到支持,并且接口足够通用,以便将来更容易地集成更复杂的方法。

如何使用优化器

要使用 torch.optim,您必须构建一个优化器对象,该对象将保存当前状态并根据计算出的梯度更新参数。

构建优化器

要构建一个 Optimizer,您必须向其提供一个包含要优化的参数(全部应为 Parameter)或命名参数((str, Parameter) 元组)的可迭代对象。然后,您可以指定优化器特定的选项,例如学习率、权重衰减等。

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

命名参数示例

optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)

每参数选项

优化器 (Optimizer) 也支持指定每参数选项。为此,您无需传递 Variable 的可迭代对象,而是传递 dict 的可迭代对象。它们中的每一个都将定义一个独立的参数组,并且应包含一个 params 键,该键包含属于该组的参数列表。其他键应与优化器接受的关键字参数匹配,并将用作此组的优化选项。

例如,这在需要指定每层学习率时非常有用

optim.SGD([
                {'params': model.base.parameters(), 'lr': 1e-2},
                {'params': model.classifier.parameters()}
            ], lr=1e-3, momentum=0.9)

optim.SGD([
                {'params': model.base.named_parameters(), 'lr': 1e-2},
                {'params': model.classifier.named_parameters()}
            ], lr=1e-3, momentum=0.9)

这意味着 model.base 的参数将使用 1e-2 的学习率,而 model.classifier 的参数将保持默认学习率 1e-3。最后,所有参数都将使用 0.9 的动量。

注意

您仍然可以将选项作为关键字参数传递。这些选项将用作默认值,用于未覆盖它们的参数组。当您只想改变一个选项,同时保持所有其他选项在参数组之间一致时,这很有用。

另请考虑以下与参数的独特惩罚相关的示例。请记住,parameters() 返回一个包含所有可学习参数的可迭代对象,包括偏置项和其他可能偏好独特惩罚的参数。为了解决这个问题,可以为每个参数组指定单独的惩罚权重

bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]

optim.SGD([
                {'params': others},
                {'params': bias_params, 'weight_decay': 0}
            ], weight_decay=1e-2, lr=1e-2)

通过这种方式,偏置项与非偏置项分开处理,并且专门为偏置项设置了 0weight_decay,以避免对该组施加任何惩罚。

执行优化步进

所有优化器都实现了 step() 方法,该方法用于更新参数。它可以通过两种方式使用

optimizer.step()

这是大多数优化器支持的简化版本。一旦使用例如 backward() 计算出梯度,就可以调用此函数。

示例

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)

一些优化算法,例如共轭梯度法和 LBFGS,需要多次重新评估函数,因此您必须传入一个闭包,允许它们重新计算您的模型。该闭包应清除梯度,计算损失,并返回该损失值。

示例

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

基类

class torch.optim.Optimizer(params, defaults)[源码][源码]

所有优化器的基类。

警告

参数需要指定为具有确定性顺序的集合,并且该顺序在不同运行之间保持一致。不满足这些属性的对象示例包括集合和字典值的迭代器。

参数
  • params (iterable) – `torch.Tensor` 或 dict 的可迭代对象。指定应优化哪些 Tensor。

  • defaults (dict[str, Any]) – (dict):一个字典,包含优化选项的默认值(当参数组未指定时使用)。

Optimizer.add_param_group

Optimizerparam_groups 添加一个参数组。

Optimizer.load_state_dict

加载优化器状态。

Optimizer.register_load_state_dict_pre_hook

注册一个 load_state_dict 前置钩子,它将在调用 load_state_dict() 之前被调用。其签名应如下::。

Optimizer.register_load_state_dict_post_hook

注册一个 load_state_dict 后置钩子,它将在调用 load_state_dict() 之后被调用。其签名应如下::。

Optimizer.state_dict

将优化器的状态作为 dict 返回。

Optimizer.register_state_dict_pre_hook

注册一个 state dict 前置钩子,它将在调用 state_dict() 之前被调用。

Optimizer.register_state_dict_post_hook

注册一个 state dict 后置钩子,它将在调用 state_dict() 之后被调用。

Optimizer.step

执行单次优化步进以更新参数。

Optimizer.register_step_pre_hook

注册一个优化器步进前置钩子,它将在优化器步进之前被调用。

Optimizer.register_step_post_hook

注册一个优化器步进后置钩子,它将在优化器步进之后被调用。

Optimizer.zero_grad

重置所有优化的 torch.Tensor 的梯度。

算法

Adadelta

实现 Adadelta 算法。

Adafactor

实现 Adafactor 算法。

Adagrad

实现 Adagrad 算法。

Adam

实现 Adam 算法。

AdamW

实现 AdamW 算法,其中权重衰减不会累积到动量或方差中。

SparseAdam

SparseAdam 实现了 Adam 算法的掩码版本,适用于稀疏梯度。

Adamax

实现 Adamax 算法(基于无穷范数的 Adam 变体)。

ASGD

实现平均随机梯度下降 (Averaged Stochastic Gradient Descent)。

LBFGS

实现 L-BFGS 算法。

NAdam

实现 NAdam 算法。

RAdam

实现 RAdam 算法。

RMSprop

实现 RMSprop 算法。

Rprop

实现弹性反向传播算法 (Resilient Backpropagation)。

SGD

实现随机梯度下降(可选择带有动量)。

我们的许多算法都有针对性能、可读性或通用性优化的各种实现方式,因此如果用户未指定特定实现,我们尝试默认选择当前设备上通常最快的实现。

我们有 3 类主要的实现:for 循环、foreach (多 Tensor) 和 fused (融合)。最直接的实现是参数上的 for 循环,其中包含大量的计算块。For 循环通常比我们的 foreach 实现慢,因为 foreach 实现将参数组合成一个多 Tensor,并一次性运行大量的计算块,从而节省了许多顺序的 kernel 调用。我们的一些优化器甚至有更快的 fused 实现,它们将大量的计算块融合到一个 kernel 中。我们可以将 foreach 实现视为水平融合,而 fused 实现则在此基础上进行垂直融合。

一般来说,这 3 种实现的性能排序是 fused > foreach > for-loop。因此,在适用情况下,我们默认使用 foreach 而非 for 循环。适用情况意味着 foreach 实现可用,用户未指定任何实现特定的关键字参数(例如 fused、foreach、differentiable),并且所有 tensor 都是原生的。请注意,虽然 fused 实现应该比 foreach 更快,但这些实现较新,我们希望在全面推广之前给予它们更多的测试时间。下表总结了每种实现的稳定性状态,欢迎您尝试!

下表显示了每种算法的可用实现和默认实现

算法

默认

是否支持 foreach?

是否支持 fused?

Adadelta

foreach

Adafactor

for-loop

Adagrad

foreach

是 (仅 CPU)

Adam

foreach

AdamW

foreach

SparseAdam

for-loop

Adamax

foreach

ASGD

foreach

LBFGS

for-loop

NAdam

foreach

RAdam

foreach

RMSprop

foreach

Rprop

foreach

SGD

foreach

下表显示了 fused 实现的稳定性状态

算法

CPU

CUDA

MPS

Adadelta

不支持

不支持

不支持

Adafactor

不支持

不支持

不支持

Adagrad

Beta

不支持

不支持

Adam

Beta

稳定

Beta

AdamW

Beta

稳定

Beta

SparseAdam

不支持

不支持

不支持

Adamax

不支持

不支持

不支持

ASGD

不支持

不支持

不支持

LBFGS

不支持

不支持

不支持

NAdam

不支持

不支持

不支持

RAdam

不支持

不支持

不支持

RMSprop

不支持

不支持

不支持

Rprop

不支持

不支持

不支持

SGD

Beta

Beta

Beta

如何调整学习率

torch.optim.lr_scheduler.LRScheduler 提供了几种根据 epoch 数调整学习率的方法。torch.optim.lr_scheduler.ReduceLROnPlateau 允许根据一些验证指标动态降低学习率。

学习率调度应在优化器更新之后应用;例如,您应该这样编写代码

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

大多数学习率调度器可以串联调用(也称为链式调度器)。结果是每个调度器依次应用于前一个调度器获得的学习率。

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()

在文档的许多地方,我们将使用以下模板来指代调度器算法。

>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()

警告

在 PyTorch 1.1.0 之前,学习率调度器预期在优化器更新之前调用;1.1.0 以 BC 兼容性断裂的方式改变了这种行为。如果您在优化器更新(调用 optimizer.step())之前使用学习率调度器(调用 scheduler.step()),这将跳过学习率调度的第一个值。如果您在升级到 PyTorch 1.1.0 后无法复现结果,请检查您是否在错误的时间调用了 scheduler.step()

lr_scheduler.LRScheduler

在优化过程中调整学习率。

lr_scheduler.LambdaLR

设置初始学习率。

lr_scheduler.MultiplicativeLR

将每个参数组的学习率乘以指定函数中给定的因子。

lr_scheduler.StepLR

每经过 step_size 个 epoch,将每个参数组的学习率乘以 gamma 进行衰减。

lr_scheduler.MultiStepLR

一旦 epoch 数达到某个里程碑,将每个参数组的学习率乘以 gamma 进行衰减。

lr_scheduler.ConstantLR

将每个参数组的学习率乘以一个小的常数因子。

lr_scheduler.LinearLR

通过线性改变一个小的乘法因子来衰减每个参数组的学习率。

lr_scheduler.ExponentialLR

每经过一个 epoch,将每个参数组的学习率乘以 gamma 进行衰减。

lr_scheduler.PolynomialLR

在给定的 total_iters 中,使用多项式函数衰减每个参数组的学习率。

lr_scheduler.CosineAnnealingLR

使用余弦退火调度设置每个参数组的学习率。

lr_scheduler.ChainedScheduler

链接一个学习率调度器列表。

lr_scheduler.SequentialLR

包含一个预期在优化过程中按顺序调用的调度器列表。

lr_scheduler.ReduceLROnPlateau

当某个指标停止改进时降低学习率。

lr_scheduler.CyclicLR

根据周期性学习率策略 (CLR) 设置每个参数组的学习率。

lr_scheduler.OneCycleLR

根据 1cycle 学习率策略设置每个参数组的学习率。

lr_scheduler.CosineAnnealingWarmRestarts

使用余弦退火调度设置每个参数组的学习率。

如何利用命名参数加载优化器 state dict

函数 load_state_dict() 如果存在,则会存储加载的 state dict 中的可选 param_names 内容。然而,加载优化器状态的过程本身不受影响,因为参数的顺序对于保持兼容性至关重要(以防顺序不同)。要利用加载的 state dict 中的参数名称,需要根据期望的行为实现一个自定义的 register_load_state_dict_pre_hook

例如,当模型架构发生变化,但权重和优化器状态需要保持不变时,这会很有用。以下示例演示了如何实现此自定义。

示例

class OneLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)

    def forward(self, x):
        return self.fc(x)

model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

假设 model 实现了一个专家 (MoE),我们希望复制它并恢复训练以得到两个专家,这两个专家都以与 fc 层相同的方式初始化。对于下面的 model2,我们创建了两个与 fc 相同的层,并通过将 model 的模型权重和优化器状态加载到 model2fc1fc2 中(并相应地调整它们)来恢复训练。

class TwoLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(3, 4)

    def forward(self, x):
        return (self.fc1(x) + self.fc2(x)) / 2

model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

要使用前一个优化器的 state dict 加载 optimizer2 的 state dict,以便 fc1fc2 都将使用 fc 优化器状态的副本进行初始化(以从 fc 恢复每个层的训练),我们可以使用以下钩子

def adapt_state_dict_ids(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc1.weight': 'fc.weight',
        'fc1.bias': 'fc.bias',
        'fc2.weight': 'fc.weight',
        'fc2.bias': 'fc.bias'
    }
    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
        id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
        # Copy the state of the corresponding parameter
        if id_in_loaded in state_dict['state']:
            adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

这确保了在模型加载期间,将使用调整后的 state_dict 来为 model2 的层设置正确的状态。请注意,此代码是专门为本示例设计的(例如,假设只有一个参数组),其他情况可能需要不同的调整。

以下示例展示了在模型结构变化时,如何处理加载的 state dict 中缺失的参数。Model_bypass 添加了一个新的 bypass 层,这在原始的 Model1 中不存在。为了恢复训练,使用了一个自定义的 adapt_state_dict_missing_param 钩子来调整优化器的 state_dict,确保现有参数被正确映射,而缺失的参数(如 bypass 层)保持不变(如本示例中初始化的一样)。这种方法使得即使模型发生变化,也能平滑加载和恢复优化器状态。新的 bypass 层将从头开始训练。

class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x) + x


model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

class Model_bypass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)
        self.bypass = nn.Linear(5, 5, bias=False)
        torch.nn.init.eye_(self.bypass.weight)

    def forward(self, x):
        return self.fc(x) + self.bypass(x)

model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

def adapt_state_dict_missing_param(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc.weight': 'fc.weight',
        'fc.bias': 'fc.bias',
        'bypass.weight': None,
    }

    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        if name_in_loaded in state_dict['param_groups'][0]['param_names']:
            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

作为第三个示例,可以使用此钩子根据参数名称加载状态,而不是根据参数顺序加载(默认方法)。

def names_matching(optimizer, state_dict):
    assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
    adapted_state_dict = deepcopy(optimizer.state_dict())
    for g_ind in range(len(state_dict['param_groups'])):
        assert len(state_dict['param_groups'][g_ind]['params']) == len(
            optimizer.state_dict()['param_groups'][g_ind]['params'])

        for k, v in state_dict['param_groups'][g_ind].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][g_ind][k] = v

        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][g_ind]['params'],
                optimizer.state_dict()['param_groups'][g_ind]['param_names']):
            index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
            id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

权重平均 (SWA 和 EMA)

torch.optim.swa_utils.AveragedModel 实现了随机权重平均 (SWA) 和指数移动平均 (EMA),torch.optim.swa_utils.SWALR 实现了 SWA 学习率调度器,而 torch.optim.swa_utils.update_bn() 是一个用于在训练结束时更新 SWA/EMA 批归一化统计信息的实用函数。

SWA 源自论文 Averaging Weights Leads to Wider Optima and Better Generalization

EMA 是一种广为人知的技术,通过减少所需的权重更新次数来缩短训练时间。它是 Polyak 平均的一种变体,但使用指数权重而非在迭代中分配等权重。

构建平均模型

`AveragedModel` 类用于计算 SWA 或 EMA 模型的权重。

您可以通过运行以下代码创建一个 SWA 平均模型

>>> averaged_model = AveragedModel(model)

通过如下指定 multi_avg_fn 参数来构建 EMA 模型

>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

`Decay` 是一个介于 0 到 1 之间的参数,用于控制平均参数的衰减速度。如果未提供给 torch.optim.swa_utils.get_ema_multi_avg_fn(),默认值为 0.999。Decay 值应接近 1.0,因为较小的值可能导致优化收敛问题。

torch.optim.swa_utils.get_ema_multi_avg_fn() 返回一个函数,该函数将以下 EMA 方程应用于权重

Wt+1EMA=αWtEMA+(1α)WtmodelW^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t

其中 alpha 是 EMA 衰减。

此处,模型 model 可以是任意 torch.nn.Module 对象。averaged_model 将跟踪 model 参数的运行平均值。要更新这些平均值,您应该在 optimizer.step() 之后使用 update_parameters() 函数

>>> averaged_model.update_parameters(model)

对于 SWA 和 EMA,此调用通常紧随优化器 step() 之后执行。对于 SWA,在训练开始时通常会跳过一定的步数。

自定义平均策略

默认情况下,torch.optim.swa_utils.AveragedModel 计算您提供的参数的运行等权重平均值,但您也可以使用 avg_fnmulti_avg_fn 参数来使用自定义平均函数

  • avg_fn 允许定义一个函数,该函数对每个参数元组(平均参数,模型参数)进行操作,并应返回新的平均参数。

  • multi_avg_fn 允许定义更高效的操作,这些操作同时作用于参数列表的元组(平均参数列表,模型参数列表),例如使用 torch._foreach* 函数。此函数必须就地更新平均参数。

在以下示例中,ema_model 使用 avg_fn 参数计算指数移动平均(EMA)。

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

在以下示例中,ema_model 使用更高效的 multi_avg_fn 参数计算指数移动平均(EMA)。

>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))

SWA 学习率调度器

通常,在 SWA 中,学习率设置为一个较高的常数值。SWALR 是一个学习率调度器,它将学习率退火到固定值,然后保持不变。例如,以下代码创建了一个调度器,它在每个参数组中,在 5 个 epoch 内将学习率从初始值线性退火到 0.05

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

您也可以通过设置 anneal_strategy="cos" 来使用余弦退火到固定值,而不是线性退火。

处理批归一化

update_bn() 是一个实用函数,它允许在训练结束时,在给定的 dataloader loader 上计算 SWA 模型的批归一化统计信息。

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn()swa_model 应用于 dataloader 中的每个元素,并计算模型中每个批归一化层的激活统计信息。

警告

update_bn() 假定 dataloader loader 中的每个批次要么是一个 tensor,要么是一个 tensor 列表,其中第一个元素是网络 swa_model 应该应用于的 tensor。如果您的 dataloader 具有不同的结构,您可以通过使用 swa_model 对数据集的每个元素执行一次前向传播来更新 swa_model 的批归一化统计信息。

综合使用:SWA

在下面的示例中,swa_model 是 SWA 模型,它累积权重的平均值。我们总共训练模型 300 个 epoch,并在 epoch 160 时切换到 SWA 学习率调度器并开始收集参数的 SWA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

综合使用:EMA

在下面的示例中,ema_model 是 EMA 模型,它以 0.999 的衰减率累积权重的指数衰减平均值。我们总共训练模型 300 个 epoch 并立即开始收集 EMA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>           ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)

swa_utils.AveragedModel

实现了用于随机权重平均 (SWA) 和指数移动平均 (EMA) 的平均模型。

swa_utils.SWALR

将每个参数组中的学习率退火到固定值。

torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[source][source]

获取对多个参数应用指数移动平均 (EMA) 的函数。

torch.optim.swa_utils.update_bn(loader, model, device=None)[source][source]

更新模型中的 BatchNorm running_mean, running_var 缓冲区。

它对 loader 中的数据执行一次遍历,以估计模型中 BatchNorm 层的激活统计信息。

参数
  • loader (torch.utils.data.DataLoader) – 用于计算激活统计信息的数据集加载器。每个数据批次应该是一个 tensor,或者一个列表/元组,其第一个元素是包含数据的 tensor。

  • model (torch.nn.Module) – 需要更新 BatchNorm 统计信息的模型。

  • device (torch.device, optional) – 如果设置,数据在传递给 device 之前将转移到 model

示例

>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)

注意

update_bn 实用工具假定 loader 中的每个数据批次要么是一个 tensor,要么是一个 tensor 的列表或元组;在后一种情况下,假定应在对应于数据批次的列表或元组的第一个元素上调用 model.forward()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源