快捷方式

在 PyTorch 中保存和加载通用检查点

保存和加载用于推理或恢复训练的通用检查点模型对于从上次中断的地方继续操作很有帮助。保存通用检查点时,必须保存的不仅仅是模型的 state_dict。重要的是还要保存优化器的 state_dict,因为它包含在模型训练过程中更新的缓冲区和参数。根据您自己的算法,您可能还希望保存的其他项目包括您停止的 epoch、最新记录的训练损失、外部 torch.nn.Embedding 层等等。

简介

要保存多个检查点,必须将它们组织到字典中,并使用 torch.save() 对字典进行序列化。一个常见的 PyTorch 约定是使用 .tar 文件扩展名保存这些检查点。要加载这些项目,首先初始化模型和优化器,然后使用 torch.load() 在本地加载字典。从这里,您可以通过像预期的那样查询字典轻松访问保存的项目。

在本食谱中,我们将探讨如何保存和加载多个检查点。

设置

在开始之前,如果尚未安装 torch,则需要安装它。

pip install torch

步骤

  1. 导入加载数据所需的所有库

  2. 定义并初始化神经网络

  3. 初始化优化器

  4. 保存通用检查点

  5. 加载通用检查点

1. 导入加载数据所需库

在本教程中,我们将使用torch及其子模块torch.nntorch.optim

import torch
import torch.nn as nn
import torch.optim as optim

2. 定义并初始化神经网络

为了举例说明,我们将创建一个用于训练图像的神经网络。要了解更多信息,请参阅“定义神经网络”教程。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

3. 初始化优化器

我们将使用带动量的SGD。

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 保存通用检查点

收集所有相关信息并构建字典。

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

5. 加载通用检查点

请记住,首先初始化模型和优化器,然后在本地加载字典。

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

在运行推理之前,必须调用model.eval()将dropout和批归一化层设置为评估模式。如果不这样做,将导致推理结果不一致。

如果希望恢复训练,请调用model.train()以确保这些层处于训练模式。

恭喜!您已成功保存和加载了一个通用检查点,用于在PyTorch中进行推理和/或恢复训练。

脚本总运行时间:(0分钟0.000秒)

由Sphinx-Gallery生成图库

文档

访问PyTorch的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源