注意
点击此处下载完整的示例代码
加载 nn.Module
检查点的技巧¶
创建于:2023 年 10 月 3 日 | 最后更新:2024 年 8 月 27 日 | 最后验证:2024 年 11 月 5 日
如果您正在加载检查点并希望尽可能减少计算和内存使用量,本教程将分享一些推荐的做法。特别是,我们将讨论
torch.load
上的mmap
关键字参数torch.device()
上下文管理器nn.Module.load_state_dict()
上的assign
关键字参数
注意
此食谱需要 PyTorch 2.1.0 或更高版本。
让我们考虑一个简单的 nn.Module
,其中包含线性层列表
import torch
from torch import nn
import time
class SomeModule(torch.nn.Module):
def __init__(self, size):
super().__init__()
self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])
def forward(self, x):
return self.linears(x)
m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')
以下代码片段演示了 torch.load
的 mmap
关键字参数、torch.device()
上下文管理器和 nn.Module.load_state_dict()
的 assign
关键字参数的用法。
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
with torch.device('meta'):
meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)
<All keys matched successfully>
将下面的代码片段与上面的代码片段进行比较
state_dict = torch.load('checkpoint.pth', weights_only=True)
m = SomeModule(1000)
m.load_state_dict(state_dict)
<All keys matched successfully>
第二个示例未使用上面列出的任何功能,并且在加载检查点时计算和内存效率较低。在以下部分中,我们将更详细地讨论每个功能。
使用 torch.load(mmap=True)
¶
首先,让我们考虑一下当我们使用 torch.load
加载检查点时会发生什么。当我们使用 torch.save
保存检查点时,张量存储会标记它们保存的设备。使用 torch.load
,张量存储将加载到它们标记的设备(除非使用 map_location
标志覆盖此行为)。为了便于解释,让我们假设张量保存在 CPU 上。这意味着在第一行,所有张量存储都将加载到 CPU RAM 中,这在以下情况下可能是不希望的
CPU RAM 小于检查点的大小。
在执行例如某些每张量处理之前,等待整个检查点加载到 RAM 中。
start_time = time.time()
state_dict = torch.load('checkpoint.pth', weights_only=True)
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
loading time without mmap=0.01737356185913086
torch.load
的 mmap
关键字参数尝试解决上述两个问题。顾名思义,torch.load
的 mmap
关键字参数使用 mmap 调用,该调用将磁盘上的文件映射到虚拟内存中,并让操作系统自动处理加载和卸载到物理内存中。当传递此标志时,张量存储将被内存映射。
start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
loading time with mmap=0.003757953643798828
如上所述,可以使用此参数对检查点执行每张量处理,而无需预先将所有张量存储加载到 CPU 内存中。例如
def my_special_routine(t, device):
# this could be a much fancier operation
return t.to(dtype=torch.bfloat16, device=device)
def my_processing_function(key, device):
t = state_dict[key]
processed_t = my_special_routine(t, device)
del t
state_dict[key] = processed_t
for key in state_dict.keys():
device = torch.device('cuda')
my_processing_function(key, device)
使用 torch.device('meta')
¶
接下来,让我们考虑模块的创建。
m = SomeModule(1000)
这会为所有参数/缓冲区分配内存,并按照 SomeModule.__init__()
中定义的默认初始化方案对其进行初始化,当我们想要加载检查点时,这是浪费的,原因如下
初始化内核的结果将被
load_state_dict()
覆盖,而从未被使用,因此初始化是浪费的。我们在 RAM 中为这些参数/缓冲区分配内存,而保存的状态字典的
torch.load
也为检查点中的参数/缓冲区分配 RAM 中的内存。
为了解决这两个问题,当实例化 nn.Module()
时,我们可以使用 device='meta'
的 torch.device()
上下文管理器。
torch.device() 上下文管理器确保工厂调用将像传递指定的 device
作为参数一样执行。torch.device('meta')
上的张量不携带数据。但是,它们拥有张量携带的所有其他元数据,例如 .size()
、.stride()
、.requires_grad
等。
with torch.device('meta'):
new_m = SomeModule(1000)
使用 load_state_dict(assign=True)
¶
接下来,我们考虑状态字典的加载。
m.load_state_dict(state_dict)
<All keys matched successfully>
nn.Module.load_state_dict()
通常通过就地 param_in_model.copy_(param_in_state_dict)
实现。这意味着状态字典中具有相应键的参数/缓冲区将复制到 nn.Module
中的参数/缓冲区中。
但是,就地复制到 meta
设备上的张量是空操作。为了避免这种情况,我们可以将 assign=True
关键字参数传递给 load_state_dict()
。
这里需要注意的是,由于优化器持有对 nn.Module.parameters()
的引用,因此如果传递了 assign=True
,则必须在从状态字典加载模块后初始化优化器。
# As of PyTorch 2.3.0, one can use ``torch.__future__.set_swap_module_params_on_conversion`` to
# avoid this caveat. This `recipe <https://pytorch.ac.cn/tutorials/recipes/recipes/swap_tensors.html>`_
# provides more details.
new_m.load_state_dict(state_dict, assign=True)
# Before 2.3.0, this MUST be done AFTER the load_state_dict with assign.
# In versions >= 2.3.0, one can consider setting ``torch.__future__.set_swap_module_params_on_conversion``
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)
结论¶
总而言之,在本教程中,我们学习了 torch.load(mmap=True)
、带有 device=meta
的 torch.device()
上下文管理器和 nn.Module.load_state_dict(assign=True)
,以及如何在从检查点加载模型时使用这些工具。
脚本的总运行时间: ( 0 分钟 0.342 秒)