注意
点击此处下载完整示例代码
从检查点加载 nn.Module
的技巧¶
创建于:2023 年 10 月 03 日 | 最近更新:2024 年 8 月 27 日 | 最近验证:2024 年 11 月 05 日
如果您正在加载检查点并希望尽可能减少计算和内存使用,本教程将分享一些推荐的实践方法。特别是,我们将讨论
注意
本教程需要 PyTorch 2.1.0 或更高版本。
让我们考虑一个包含线性层列表的简单 nn.Module
import torch
from torch import nn
import time
class SomeModule(torch.nn.Module):
def __init__(self, size):
super().__init__()
self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])
def forward(self, x):
return self.linears(x)
m = SomeModule(1000)
torch.save(m.state_dict(), 'checkpoint.pth')
以下代码片段演示了如何使用 torch.load
的 mmap
关键字参数、torch.device()
上下文管理器以及 nn.Module.load_state_dict()
的 assign
关键字参数。
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
with torch.device('meta'):
meta_m = SomeModule(1000)
meta_m.load_state_dict(state_dict, assign=True)
<All keys matched successfully>
将下面的代码片段与上面的进行比较
state_dict = torch.load('checkpoint.pth', weights_only=True)
m = SomeModule(1000)
m.load_state_dict(state_dict)
<All keys matched successfully>
第二个示例没有使用上面列出的任何特性,在加载检查点时计算和内存效率较低。在以下章节中,我们将进一步详细讨论每个特性。
使用 torch.load(mmap=True)
¶
首先,让我们考虑使用 torch.load
加载检查点时会发生什么。当我们使用 torch.save
保存检查点时,张量存储(tensor storages)会被标记上保存时所在的设备。使用 torch.load
时,张量存储将被加载到它们被标记的设备上(除非使用 map_location
标志覆盖此行为)。为了便于解释,让我们假设张量保存在 CPU 上。这意味着在第一行,所有张量存储都将加载到 CPU RAM 中,这在以下情况下可能是不希望的:
CPU RAM 小于检查点的大小。
在执行例如一些每张量处理之前,等待整个检查点加载到 RAM 中。
start_time = time.time()
state_dict = torch.load('checkpoint.pth', weights_only=True)
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
loading time without mmap=0.02578878402709961
start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True, weights_only=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
loading time with mmap=0.006429910659790039
如上所述,可以使用此参数对检查点进行每张量处理,而无需预先将所有张量存储加载到 CPU 内存中。例如
def my_special_routine(t, device):
# this could be a much fancier operation
return t.to(dtype=torch.bfloat16, device=device)
def my_processing_function(key, device):
t = state_dict[key]
processed_t = my_special_routine(t, device)
del t
state_dict[key] = processed_t
for key in state_dict.keys():
device = torch.device('cuda')
my_processing_function(key, device)
使用 torch.device('meta')
¶
接下来,让我们考虑模块的创建。
m = SomeModule(1000)
这会为所有参数/缓冲区分配内存,并根据 SomeModule.__init__()
中定义的默认初始化方案对其进行初始化,这在我们想加载检查点时是浪费的,原因如下:
初始化核函数的结果将被
load_state_dict()
覆盖,而从未被使用,因此初始化是浪费的。我们正在 RAM 中为这些参数/缓冲区分配内存,而加载保存的状态字典的
torch.load
也会在 RAM 中为检查点中的参数/缓冲区分配内存。
为了解决这两个问题,我们可以在实例化 nn.Module()
时使用 torch.device()
上下文管理器并指定 device='meta'
。
with torch.device('meta'):
new_m = SomeModule(1000)
使用 load_state_dict(assign=True)
¶
接下来,我们考虑加载状态字典。
m.load_state_dict(state_dict)
<All keys matched successfully>
然而,就地复制到 meta
设备上的张量是一个空操作(no-op)。为了避免这种情况,我们可以向 load_state_dict()
传递 assign=True
关键字参数。
这里需要注意的一点是,由于优化器持有对 nn.Module.parameters()
的引用,如果在加载模块状态字典时传递了 assign=True
,则优化器必须在模块加载完成后初始化。
# As of PyTorch 2.3.0, one can use ``torch.__future__.set_swap_module_params_on_conversion`` to
# avoid this caveat. This `recipe <https://pytorch.ac.cn/tutorials/recipes/recipes/swap_tensors.html>`_
# provides more details.
new_m.load_state_dict(state_dict, assign=True)
# Before 2.3.0, this MUST be done AFTER the load_state_dict with assign.
# In versions >= 2.3.0, one can consider setting ``torch.__future__.set_swap_module_params_on_conversion``
opt = torch.optim.SGD(new_m.parameters(), lr=1e-3)
结论¶
总而言之,在本教程中,我们学习了 torch.load(mmap=True)
、带有 device=meta
的 torch.device()
上下文管理器以及 nn.Module.load_state_dict(assign=True)
,以及如何在从检查点加载模型时使用这些工具进行辅助。
脚本总运行时间: ( 0 分钟 0.296 秒)