注意
点击此处下载完整示例代码
在 PyTorch 中使用不同模型的参数温启动模型¶
创建日期:2020 年 4 月 20 日 | 最后更新:2024 年 8 月 27 日 | 最后验证:2024 年 11 月 5 日
在迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见场景。利用已训练的参数,即使只有一部分可用,也将有助于温启动训练过程,并有望帮助你的模型比从头训练收敛得快得多。
引言¶
无论是从缺少某些键的部分 state_dict
加载,还是加载一个键多于要加载的模型中的 state_dict
,你都可以将 load_state_dict()
函数中的 strict 参数设置为 False
以忽略不匹配的键。在本篇精选代码 (recipe) 中,我们将实验使用不同模型的参数来温启动模型。
步骤¶
导入加载数据所需的所有库
定义并初始化神经网络 A 和 B
保存模型 A
加载到模型 B 中
1. 导入加载数据所需的库¶
对于本篇精选代码 (recipe),我们将使用 torch
及其附属模块 torch.nn
和 torch.optim
。
import torch
import torch.nn as nn
import torch.optim as optim
2. 定义并初始化神经网络 A 和 B¶
为了示例目的,我们将创建一个用于训练图像的神经网络。要了解更多信息,请参阅 定义神经网络 精选代码 (recipe)。我们将创建两个神经网络,以便将类型 A 的一个参数加载到类型 B 中。
class NetA(nn.Module):
def __init__(self):
super(NetA, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netA = NetA()
class NetB(nn.Module):
def __init__(self):
super(NetB, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netB = NetB()
3. 保存模型 A¶
# Specify a path to save to
PATH = "model.pt"
torch.save(netA.state_dict(), PATH)
4. 加载到模型 B 中¶
如果你想将参数从一层加载到另一层,但有些键不匹配,只需更改你正在加载的 state_dict 中参数键的名称,使其与你要加载的模型中的键相匹配。
netB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)
你可以看到所有键都已成功匹配!
恭喜!你已成功地在 PyTorch 中使用不同模型的参数温启动了模型。