• 教程 >
  • 加载 nn.Module 检查点的技巧
快捷方式

加载 nn.Module 检查点的技巧

创建于:2023 年 10 月 3 日 | 最后更新:2024 年 8 月 27 日 | 最后验证:2024 年 11 月 5 日

作者: Mikayla Gawarecki

如果您正在加载检查点并希望尽可能减少计算和内存使用量,本教程将分享一些推荐的做法。特别是,我们将讨论

  1. torch.load 上的 mmap 关键字参数

  2. torch.device() 上下文管理器

  3. nn.Module.load_state_dict() 上的 assign 关键字参数

注意

此食谱需要 PyTorch 2.1.0 或更高版本。

让我们考虑一个简单的 nn.Module,其中包含线性层列表

import torch
from torch import nn
import time

class SomeModule(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])

    def forward(self, x):
        return self.linears(x)


m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')

以下代码片段演示了 torch.loadmmap 关键字参数、torch.device() 上下文管理器和 nn.Module.load_state_dict()assign 关键字参数的用法。

state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
with torch.device('meta'):
  meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)
<All keys matched successfully>

将下面的代码片段与上面的代码片段进行比较

state_dict = torch.load('checkpoint.pth', weights_only=True)
m = SomeModule(1000)
m.load_state_dict(state_dict)
<All keys matched successfully>

第二个示例未使用上面列出的任何功能,并且在加载检查点时计算和内存效率较低。在以下部分中,我们将更详细地讨论每个功能。

使用 torch.load(mmap=True)

首先,让我们考虑一下当我们使用 torch.load 加载检查点时会发生什么。当我们使用 torch.save 保存检查点时,张量存储会标记它们保存的设备。使用 torch.load,张量存储将加载到它们标记的设备(除非使用 map_location 标志覆盖此行为)。为了便于解释,让我们假设张量保存在 CPU 上。这意味着在第一行,所有张量存储都将加载到 CPU RAM 中,这在以下情况下可能是不希望的

  • CPU RAM 小于检查点的大小。

  • 在执行例如某些每张量处理之前,等待整个检查点加载到 RAM 中。

start_time = time.time()
state_dict = torch.load('checkpoint.pth', weights_only=True)
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
loading time without mmap=0.01737356185913086

torch.loadmmap 关键字参数尝试解决上述两个问题。顾名思义,torch.loadmmap 关键字参数使用 mmap 调用,该调用将磁盘上的文件映射到虚拟内存中,并让操作系统自动处理加载和卸载到物理内存中。当传递此标志时,张量存储将被内存映射。

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
loading time with mmap=0.003757953643798828

如上所述,可以使用此参数对检查点执行每张量处理,而无需预先将所有张量存储加载到 CPU 内存中。例如

def my_special_routine(t, device):
    # this could be a much fancier operation
    return t.to(dtype=torch.bfloat16, device=device)

def my_processing_function(key, device):
    t = state_dict[key]
    processed_t = my_special_routine(t, device)
    del t
    state_dict[key] = processed_t

for key in state_dict.keys():
    device = torch.device('cuda')
    my_processing_function(key, device)

使用 torch.device('meta')

接下来,让我们考虑模块的创建。

m = SomeModule(1000)

这会为所有参数/缓冲区分配内存,并按照 SomeModule.__init__() 中定义的默认初始化方案对其进行初始化,当我们想要加载检查点时,这是浪费的,原因如下

  • 初始化内核的结果将被 load_state_dict() 覆盖,而从未被使用,因此初始化是浪费的。

  • 我们在 RAM 中为这些参数/缓冲区分配内存,而保存的状态字典的 torch.load 也为检查点中的参数/缓冲区分配 RAM 中的内存。

为了解决这两个问题,当实例化 nn.Module() 时,我们可以使用 device='meta'torch.device() 上下文管理器。

torch.device() 上下文管理器确保工厂调用将像传递指定的 device 作为参数一样执行。torch.device('meta') 上的张量不携带数据。但是,它们拥有张量携带的所有其他元数据,例如 .size().stride().requires_grad 等。

with torch.device('meta'):
  new_m = SomeModule(1000)

使用 load_state_dict(assign=True)

接下来,我们考虑状态字典的加载。

m.load_state_dict(state_dict)
<All keys matched successfully>

nn.Module.load_state_dict() 通常通过就地 param_in_model.copy_(param_in_state_dict) 实现。这意味着状态字典中具有相应键的参数/缓冲区将复制到 nn.Module 中的参数/缓冲区中。

但是,就地复制到 meta 设备上的张量是空操作。为了避免这种情况,我们可以将 assign=True 关键字参数传递给 load_state_dict()

这里需要注意的是,由于优化器持有对 nn.Module.parameters() 的引用,因此如果传递了 assign=True,则必须在从状态字典加载模块后初始化优化器。

# As of PyTorch 2.3.0, one can use ``torch.__future__.set_swap_module_params_on_conversion`` to
# avoid this caveat. This `recipe <https://pytorch.ac.cn/tutorials/recipes/recipes/swap_tensors.html>`_
# provides more details.

new_m.load_state_dict(state_dict, assign=True)
# Before 2.3.0, this MUST be done AFTER the load_state_dict with assign.
# In versions >= 2.3.0, one can consider setting ``torch.__future__.set_swap_module_params_on_conversion``
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)

结论

总而言之,在本教程中,我们学习了 torch.load(mmap=True)、带有 device=metatorch.device() 上下文管理器和 nn.Module.load_state_dict(assign=True),以及如何在从检查点加载模型时使用这些工具。

脚本的总运行时间: ( 0 分钟 0.342 秒)

由 Sphinx-Gallery 生成的图库


评价本教程

© Copyright 2024, PyTorch。

使用 Sphinx 构建,主题由 theme 提供,由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源