• 教程 >
  • 从检查点加载 nn.Module 的技巧
快捷方式

从检查点加载 nn.Module 的技巧

创建于:2023 年 10 月 03 日 | 最近更新:2024 年 8 月 27 日 | 最近验证:2024 年 11 月 05 日

作者: Mikayla Gawarecki

如果您正在加载检查点并希望尽可能减少计算和内存使用,本教程将分享一些推荐的实践方法。特别是,我们将讨论

注意

本教程需要 PyTorch 2.1.0 或更高版本。

让我们考虑一个包含线性层列表的简单 nn.Module

import torch
from torch import nn
import time

class SomeModule(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])

    def forward(self, x):
        return self.linears(x)


m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')

以下代码片段演示了如何使用 torch.loadmmap 关键字参数、torch.device() 上下文管理器以及 nn.Module.load_state_dict()assign 关键字参数。

state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
with torch.device('meta'):
  meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)
<All keys matched successfully>

将下面的代码片段与上面的进行比较

state_dict = torch.load('checkpoint.pth', weights_only=True)
m = SomeModule(1000)
m.load_state_dict(state_dict)
<All keys matched successfully>

第二个示例没有使用上面列出的任何特性,在加载检查点时计算和内存效率较低。在以下章节中,我们将进一步详细讨论每个特性。

使用 torch.load(mmap=True)

首先,让我们考虑使用 torch.load 加载检查点时会发生什么。当我们使用 torch.save 保存检查点时,张量存储(tensor storages)会被标记上保存时所在的设备。使用 torch.load 时,张量存储将被加载到它们被标记的设备上(除非使用 map_location 标志覆盖此行为)。为了便于解释,让我们假设张量保存在 CPU 上。这意味着在第一行,所有张量存储都将加载到 CPU RAM 中,这在以下情况下可能是不希望的:

  • CPU RAM 小于检查点的大小。

  • 在执行例如一些每张量处理之前,等待整个检查点加载到 RAM 中。

start_time = time.time()
state_dict = torch.load('checkpoint.pth', weights_only=True)
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
loading time without mmap=0.02578878402709961

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
loading time with mmap=0.006429910659790039

如上所述,可以使用此参数对检查点进行每张量处理,而无需预先将所有张量存储加载到 CPU 内存中。例如

def my_special_routine(t, device):
    # this could be a much fancier operation
    return t.to(dtype=torch.bfloat16, device=device)

def my_processing_function(key, device):
    t = state_dict[key]
    processed_t = my_special_routine(t, device)
    del t
    state_dict[key] = processed_t

for key in state_dict.keys():
    device = torch.device('cuda')
    my_processing_function(key, device)

使用 torch.device('meta')

接下来,让我们考虑模块的创建。

m = SomeModule(1000)

这会为所有参数/缓冲区分配内存,并根据 SomeModule.__init__() 中定义的默认初始化方案对其进行初始化,这在我们想加载检查点时是浪费的,原因如下:

  • 初始化核函数的结果将被 load_state_dict() 覆盖,而从未被使用,因此初始化是浪费的。

  • 我们正在 RAM 中为这些参数/缓冲区分配内存,而加载保存的状态字典的 torch.load 也会在 RAM 中为检查点中的参数/缓冲区分配内存。

为了解决这两个问题,我们可以在实例化 nn.Module() 时使用 torch.device() 上下文管理器并指定 device='meta'

with torch.device('meta'):
  new_m = SomeModule(1000)

使用 load_state_dict(assign=True)

接下来,我们考虑加载状态字典。

m.load_state_dict(state_dict)
<All keys matched successfully>

然而,就地复制到 meta 设备上的张量是一个空操作(no-op)。为了避免这种情况,我们可以向 load_state_dict() 传递 assign=True 关键字参数。

这里需要注意的一点是,由于优化器持有对 nn.Module.parameters() 的引用,如果在加载模块状态字典时传递了 assign=True,则优化器必须在模块加载完成后初始化。

# As of PyTorch 2.3.0, one can use ``torch.__future__.set_swap_module_params_on_conversion`` to
# avoid this caveat. This `recipe <https://pytorch.ac.cn/tutorials/recipes/recipes/swap_tensors.html>`_
# provides more details.

new_m.load_state_dict(state_dict, assign=True)
# Before 2.3.0, this MUST be done AFTER the load_state_dict with assign.
# In versions >= 2.3.0, one can consider setting ``torch.__future__.set_swap_module_params_on_conversion``
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)

结论

总而言之,在本教程中,我们学习了 torch.load(mmap=True)、带有 device=metatorch.device() 上下文管理器以及 nn.Module.load_state_dict(assign=True),以及如何在从检查点加载模型时使用这些工具进行辅助。

脚本总运行时间: ( 0 分钟 0.296 秒)

由 Sphinx-Gallery 生成的画廊

文档

获取 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源