快捷方式

在 PyTorch 中使用不同模型的参数温启动模型

创建日期:2020 年 4 月 20 日 | 最后更新:2024 年 8 月 27 日 | 最后验证:2024 年 11 月 5 日

在迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见场景。利用已训练的参数,即使只有一部分可用,也将有助于温启动训练过程,并有望帮助你的模型比从头训练收敛得快得多。

引言

无论是从缺少某些键的部分 state_dict 加载,还是加载一个键多于要加载的模型中的 state_dict,你都可以将 load_state_dict() 函数中的 strict 参数设置为 False 以忽略不匹配的键。在本篇精选代码 (recipe) 中,我们将实验使用不同模型的参数来温启动模型。

设置

在开始之前,如果尚未安装 torch,我们需要先安装它。

pip install torch

步骤

  1. 导入加载数据所需的所有库

  2. 定义并初始化神经网络 A 和 B

  3. 保存模型 A

  4. 加载到模型 B 中

1. 导入加载数据所需的库

对于本篇精选代码 (recipe),我们将使用 torch 及其附属模块 torch.nntorch.optim

import torch
import torch.nn as nn
import torch.optim as optim

2. 定义并初始化神经网络 A 和 B

为了示例目的,我们将创建一个用于训练图像的神经网络。要了解更多信息,请参阅 定义神经网络 精选代码 (recipe)。我们将创建两个神经网络,以便将类型 A 的一个参数加载到类型 B 中。

class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()

3. 保存模型 A

# Specify a path to save to
PATH = "model.pt"

torch.save(netA.state_dict(), PATH)

4. 加载到模型 B 中

如果你想将参数从一层加载到另一层,但有些键不匹配,只需更改你正在加载的 state_dict 中参数键的名称,使其与你要加载的模型中的键相匹配。

netB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)

你可以看到所有键都已成功匹配!

恭喜!你已成功地在 PyTorch 中使用不同模型的参数温启动了模型。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源