快捷方式

使用 PyTorch 将多个模型保存到一个文件中

保存和加载多个模型可以帮助您重复使用之前训练过的模型。

简介

保存包含多个 torch.nn.Modules 的模型(例如 GAN、序列到序列模型或模型集成)时,您必须保存每个模型的 state_dict 和相应优化器的字典。您还可以保存任何可能帮助您恢复训练的其他项目,只需将它们追加到字典中即可。要加载模型,首先初始化模型和优化器,然后使用 torch.load() 在本地加载字典。从此处,您可以通过按预期查询字典轻松访问保存的项目。在本食谱中,我们将演示如何使用 PyTorch 将多个模型保存到一个文件中。

设置

在我们开始之前,我们需要安装 torch(如果尚未安装)。

pip install torch

步骤

  1. 导入所有加载数据的必要库

  2. 定义并初始化神经网络

  3. 初始化优化器

  4. 保存多个模型

  5. 加载多个模型

1. 导入加载数据的必要库

在本教程中,我们将使用 torch 及其子库 torch.nntorch.optim

import torch
import torch.nn as nn
import torch.optim as optim

2. 定义并初始化神经网络

为了举例说明,我们将创建一个用于训练图像的神经网络。要了解更多信息,请参阅“定义神经网络”教程。创建两个模型变量,以便最终保存。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = Net()
netB = Net()

3. 初始化优化器

我们将使用带动量的 SGD 为我们创建的每个模型构建一个优化器。

optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)

4. 保存多个模型

收集所有相关信息并构建您的字典。

# Specify a path to save to
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

4. 加载多个模型

请记住,首先初始化模型和优化器,然后在本地加载字典。

modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH, weights_only=True)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

在运行推理之前,必须调用 model.eval() 将 dropout 和批归一化层设置为评估模式。如果不这样做,会导致推理结果不一致。

如果您希望恢复训练,请调用 model.train() 以确保这些层处于训练模式。

恭喜!您已成功在 PyTorch 中保存和加载多个模型。

脚本总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

View Docs

教程

获取针对初学者和高级开发人员的深入教程

View Tutorials

资源

查找开发资源并获得问题的解答

View Resources