注意
单击 此处 下载完整的示例代码
使用 PyTorch 将多个模型保存到一个文件中¶
保存和加载多个模型可以帮助您重复使用之前训练过的模型。
简介¶
保存包含多个 torch.nn.Modules
的模型(例如 GAN、序列到序列模型或模型集成)时,您必须保存每个模型的 state_dict 和相应优化器的字典。您还可以保存任何可能帮助您恢复训练的其他项目,只需将它们追加到字典中即可。要加载模型,首先初始化模型和优化器,然后使用 torch.load()
在本地加载字典。从此处,您可以通过按预期查询字典轻松访问保存的项目。在本食谱中,我们将演示如何使用 PyTorch 将多个模型保存到一个文件中。
步骤¶
导入所有加载数据的必要库
定义并初始化神经网络
初始化优化器
保存多个模型
加载多个模型
1. 导入加载数据的必要库¶
在本教程中,我们将使用 torch
及其子库 torch.nn
和 torch.optim
。
import torch
import torch.nn as nn
import torch.optim as optim
2. 定义并初始化神经网络¶
为了举例说明,我们将创建一个用于训练图像的神经网络。要了解更多信息,请参阅“定义神经网络”教程。创建两个模型变量,以便最终保存。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netA = Net()
netB = Net()
3. 初始化优化器¶
我们将使用带动量的 SGD 为我们创建的每个模型构建一个优化器。
4. 保存多个模型¶
收集所有相关信息并构建您的字典。
# Specify a path to save to
PATH = "model.pt"
torch.save({
'modelA_state_dict': netA.state_dict(),
'modelB_state_dict': netB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
4. 加载多个模型¶
请记住,首先初始化模型和优化器,然后在本地加载字典。
modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH, weights_only=True)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
在运行推理之前,必须调用 model.eval()
将 dropout 和批归一化层设置为评估模式。如果不这样做,会导致推理结果不一致。
如果您希望恢复训练,请调用 model.train()
以确保这些层处于训练模式。
恭喜!您已成功在 PyTorch 中保存和加载多个模型。
脚本总运行时间:(0 分钟 0.000 秒)