序列化语义¶
本说明介绍了如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块以便在 C++ 中加载它们。
目录
保存和加载张量¶
torch.save()
和 torch.load()
使您能够轻松保存和加载张量
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照惯例,PyTorch 文件通常使用 '.pt' 或 '.pth' 扩展名。
torch.save()
和 torch.load()
默认使用 Python 的 pickle,因此您也可以将多个张量保存为元组、列表和字典等 Python 对象的一部分
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
包含 PyTorch 张量的自定义数据结构,如果该数据结构是可 pickle 化的,也可以被保存。
保存和加载张量保留视图¶
保存张量会保留它们的视图关系
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在幕后,这些张量共享相同的“存储 (storage)”。请参阅 张量视图 (Tensor Views) 以了解有关视图和存储的更多信息。
当 PyTorch 保存张量时,它会单独保存它们的存储对象 (storage objects) 和张量元数据。这是一个实现细节,将来可能会发生变化,但它通常可以节省空间,并让 PyTorch 轻松重建加载的张量之间的视图关系。例如,在上面的代码片段中,只有一个存储被写入 'tensors.pt'。
然而,在某些情况下,保存当前的存储对象可能是没有必要的,并且会创建过大的文件。在下面的代码片段中,一个远大于所保存张量的存储被写入到一个文件
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
保存到 'small.pt' 文件中的,不是 small 张量中的五个值,而是它与 large 共享的存储中的 999 个值被保存和加载了。
当保存的张量包含的元素少于其存储对象中的元素时,可以通过先复制 (cloning) 张量来减小保存文件的大小。复制张量会生成一个新的张量,该张量拥有一个新的存储对象,仅包含该张量中的值
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
然而,由于复制的张量是相互独立的,它们不具有原始张量之间的视图关系。如果保存小于其存储对象的张量时,文件大小和视图关系都很重要,那么必须在保存之前仔细构建新的张量,以尽量减小其存储对象的大小,但仍保留所需的视图关系。
保存和加载 torch.nn.Modules¶
另请参阅:教程:保存和加载模块
在 PyTorch 中,模块的状态通常使用“状态字典 (state dict)”进行序列化。模块的状态字典包含其所有参数和持久缓冲区 (persistent buffers)
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
出于兼容性原因,建议不要直接保存模块,而是只保存其状态字典。Python 模块甚至有一个函数 load_state_dict()
,用于从状态字典恢复其状态
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
请注意,状态字典首先使用 torch.load()
从文件中加载,然后使用 load_state_dict()
恢复状态。
即使是自定义模块和包含其他模块的模块也具有状态字典,并且可以使用此模式
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
torch.save
的序列化文件格式¶
自 PyTorch 1.6.0 版本起,除非用户设置 _use_new_zipfile_serialization=False
,否则 torch.save
默认返回未压缩的 ZIP64 归档文件。
在此归档文件中,文件按如下顺序排列
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
- 条目如下
data.pkl
是对传递给torch.save
的对象进行 pickle 化的结果,其中不包含对象内部的torch.Storage
对象byteorder
包含一个字符串,表示保存时的sys.byteorder
(“little” 或 “big”)data/
包含对象中的所有存储,其中每个存储都是一个单独的文件version
包含保存时的版本号,可在加载时使用
保存时,PyTorch 将确保每个文件的本地文件头填充到 64 字节的倍数偏移量,从而确保每个文件的偏移量是 64 字节对齐的。
注意
某些设备(如 XLA)上的张量被序列化为 pickled 的 numpy 数组。因此,它们的存储不会被序列化。在这种情况下,检查点中可能不存在 data/
。
带 weights_only=True
的 torch.load
¶
从 2.6 版本开始,如果未传递 pickle_module
参数,torch.load
将使用 weights_only=True
。
正如 torch.load()
文档中所讨论的,weights_only=True
将 torch.load
中使用的 unpickler 限制为仅执行普通 torch.Tensors
的 state_dicts
以及其他一些原始类型所需的函数/构建类。此外,与 pickle
模块提供的默认 Unpickler
不同,weights_only
Unpickler 在 unpickling 过程中不允许动态导入任何内容。
如上所述,在使用 torch.save
时,保存模块的 state_dict
是一个最佳实践。如果加载包含 nn.Module
的旧检查点,我们建议使用 weights_only=False
。加载包含张量子类的检查点时,很可能需要将某些函数/类添加到允许列表中,详情见下文。
如果 weights_only
Unpickler 在 pickle 文件中遇到默认情况下未添加到允许列表的函数或类,您应该会看到类似以下的可操作错误
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
2. Alternatively, to load with `weights_only=True` please check the recommended
steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
`torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
if you trust this class/function.
请按照错误消息中的步骤,仅在您信任这些函数或类时,才将它们添加到允许列表中。
要获取检查点中所有尚未添加到允许列表的 GLOBAL(函数/类),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint()
,它将返回一个形式为 {__module__}.{__name__}
的字符串列表。如果您信任这些函数/类,您可以导入它们,并按照错误消息的指示,通过 torch.serialization.add_safe_globals()
或上下文管理器 torch.serialization.safe_globals
将它们添加到允许列表中。
要访问用户添加到允许列表的函数/类列表,您可以使用 torch.serialization.get_safe_globals()
;要清除当前列表,请参阅 torch.serialization.clear_safe_globals()
。
排除 weights_only
故障¶
获取不安全的全局变量¶
需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint()
是静态分析检查点,某些类型可能在 unpickling 过程中动态构建,因此不会被 torch.serialization.get_unsafe_globals_in_checkpoint()
报告。一个这样的例子是 numpy 中的 dtypes
。在 numpy < 1.25
中,将 torch.serialization.get_unsafe_globals_in_checkpoint()
报告的所有函数/类添加到允许列表后,您可能会看到类似以下的错误
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>
这可以通过 {add_}safe_globals([type(np.dtype(np.float32))])
添加到允许列表。
在 numpy >=1.25
中,您会看到
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>
这可以通过 {add_}safe_globals([np.dtypes.Float32DType])
添加到允许列表。
序列化 torch.nn.Modules 并在 C++ 中加载它们¶
另请参阅:教程:在 C++ 中加载 TorchScript 模型
ScriptModules 可以序列化为 TorchScript 程序,并使用 torch.jit.load()
加载。这种序列化编码了模块的所有方法、子模块、参数和属性,并且允许序列化的程序在 C++ 中加载(即无需 Python 环境)。
torch.jit.save()
和 torch.save()
之间的区别可能不是立即可见的。torch.save()
使用 pickle 保存 Python 对象。这对于原型开发、研究和训练特别有用。torch.jit.save()
则将 ScriptModules 序列化为可以在 Python 或 C++ 中加载的格式。这在保存和加载 C++ 模块或使用 C++ 运行在 Python 中训练的模块时非常有用,这是部署 PyTorch 模型时的常见做法。
在 Python 中进行脚本化、序列化和加载模块
>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
(l0): RecursiveScriptModule(original_name=Linear)
(l1): RecursiveScriptModule(original_name=Linear) )
跟踪模块也可以使用 torch.jit.save()
保存,但需要注意的是,只序列化跟踪到的代码路径。以下示例演示了这一点
# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
if input.dim() > 1:
return torch.tensor(0)
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)
>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)
上面的模块有一个 if 语句,该语句未被跟踪的输入触发,因此它不是跟踪模块的一部分,也不会随之序列化。然而,脚本化的模块包含该 if 语句并随之序列化。有关脚本化和跟踪的更多信息,请参阅 TorchScript 文档。
最后,在 C++ 中加载模块
>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');
有关如何在 C++ 中使用 PyTorch 模块的详细信息,请参阅 PyTorch C++ API 文档。
跨 PyTorch 版本保存和加载 ScriptModules¶
PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。较旧的 PyTorch 版本可能不支持较新的模块,而较新的版本可能已删除或修改了旧的行为。这些更改在 PyTorch 的 发布说明 中有明确描述,依赖已更改功能的模块可能需要更新才能继续正常工作。在下面详细介绍的有限情况下,PyTorch 将保留序列化 ScriptModules 的历史行为,这样它们就不需要更新。
torch.div
执行整数除法¶
在 PyTorch 1.5 及更早版本中,当给定两个整数输入时,torch.div()
会执行地板除法 (floor division)
# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)
然而,在 PyTorch 1.7 中,torch.div()
将始终对其输入执行真除法 (true division),就像 Python 3 中的除法一样
# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)
torch.div()
的行为在序列化的 ScriptModules 中被保留。也就是说,即使使用较新版本的 PyTorch 加载,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 在给定两个整数输入时仍将看到 torch.div()
执行地板除法。然而,在 PyTorch 1.6 及更高版本上使用 torch.div()
并序列化的 ScriptModules 不能在更早的 PyTorch 版本中加载,因为这些更早的版本无法理解新行为。
torch.full
总是推断为浮点型 dtype¶
在 PyTorch 1.5 及更早版本中,无论给定什么填充值,torch.full()
总是返回一个浮点型张量
# PyTorch 1.5 and earlier
>>> torch.full((3,), 1) # Note the integer fill value...
tensor([1., 1., 1.]) # ...but float tensor!
然而,在 PyTorch 1.7 中,torch.full()
将从填充值推断返回张量的 dtype
# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])
>>> torch.full((3,), True)
tensor([True, True, True])
>>> torch.full((3,), 1.)
tensor([1., 1., 1.])
>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])
torch.full()
的行为在序列化的 ScriptModules 中被保留。也就是说,即使给定 bool 或整数填充值,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 默认仍将看到 torch.full 返回浮点型张量。然而,在 PyTorch 1.6 及更高版本上使用 torch.full()
并序列化的 ScriptModules 不能在更早的 PyTorch 版本中加载,因为这些更早的版本无法理解新行为。
实用函数¶
以下实用函数与序列化相关
- torch.serialization.register_package(priority, tagger, deserializer)[源][源]¶
注册用于标记和反序列化存储对象的可调用对象,并关联优先级。标记 (tagging) 在保存时将设备与存储对象关联,而反序列化 (deserializing) 在加载时将存储对象移动到适当的设备。
tagger
和deserializer
按照其priority
给定的顺序运行,直到 tagger/deserializer 返回一个非 None 的值。要覆盖全局注册表中某个设备的反序列化行为,可以注册一个优先级高于现有 tagger 的 tagger。
此函数还可用于为新设备注册 tagger 和 deserializer。
- 参数
priority (int) – 表示与 tagger 和 deserializer 关联的优先级,值越低表示优先级越高。
tagger (Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 可调用对象,接受一个存储对象并将其标记的设备以字符串形式返回,或返回 None。
deserializer (Callable[[Union[Storage, TypedStorage, UntypedStorage], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 一个可调用对象,接受存储对象和设备字符串,并返回适当设备上的存储对象或 None。
- 返回
None
示例
>>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- torch.serialization.get_crc32_options()[source][source]¶
获取
torch.save()
是否计算并为每个记录写入 crc32。默认为
True
。- 返回类型
- torch.serialization.set_crc32_options(compute_crc32)[source][source]¶
设置
torch.save()
是否计算并为每个记录写入 crc32。注意
将其设置为
False
可能会导致解压torch.save
输出时因 CRC32 损坏而失败或发出警告。但torch.load
将能够加载该文件。- 参数
compute_crc32 (bool) – 设置 crc32 计算标志
- torch.serialization.get_default_load_endianness()[source][source]¶
获取加载文件的回退字节顺序
如果保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“native”字节顺序。
- 返回
Optional[LoadEndianness]
- 返回类型
default_load_endian
- torch.serialization.set_default_load_endianness(endianness)[source][source]¶
设置加载文件的回退字节顺序
如果保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“native”字节顺序。
- 参数
endianness – 新的回退字节顺序
- torch.serialization.get_default_mmap_options()[source][source]¶
获取
torch.load()
并设置mmap=True
时的默认 mmap 选项。默认为
mmap.MAP_PRIVATE
。- 返回
int
- 返回类型
default_mmap_options
- torch.serialization.set_default_mmap_options(flags)[source][source]¶
上下文管理器或函数,用于设置
torch.load()
并设置mmap=True
时的默认 mmap 选项为 flags。目前,仅支持
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
。如果您需要添加任何其他选项,请提交 issue。注意
此功能目前不支持 Windows。
- 参数
flags (int) –
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
- torch.serialization.add_safe_globals(safe_globals)[source][source]¶
将给定的全局变量标记为对
weights_only
加载安全。例如,添加到此列表中的函数可以在反序列化时被调用,类可以被实例化并设置状态。列表中的每个项可以是函数/类本身,也可以是形式为 (函数/类, 字符串) 的元组,其中字符串是函数/类的完整路径。
在序列化格式中,每个函数都由其完整路径
{__module__}.{__qualname__}
标识。调用此 API 时,您可以提供应该与检查点中匹配的完整路径,否则将使用默认的{fn.__module__}.{fn.__qualname__}
。- 参数
safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要标记为安全的全局变量列表
示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])
- torch.serialization.get_unsafe_globals_in_checkpoint(f)[source][source]¶
返回一个字符串列表,包含
torch.save
对象中对weights_only
不安全的函数/类。对于给定的函数或类
f
,对应的字符串将是{f.__module__}.{f.__name__}
的形式。此函数将返回检查点中未标记为对
weights_only
安全的任何全局变量(无论是通过add_safe_globals()
或safe_globals
上下文,还是由torch
默认列入白名单)。注意
此函数将静态反汇编检查点中的 pickle 文件。这意味着在反序列化期间动态推送到栈上的任何类将不包含在输出中。
- class torch.serialization.safe_globals(safe_globals)[source][source]¶
上下文管理器,将某些全局变量添加为对
weights_only
加载安全。示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == []
- class torch.serialization.skip_data(materialize_fake_tensors=False)[source][source]¶
上下文管理器,用于在
torch.save
/torch.load
调用时跳过写入/读取存储字节。对于保存路径,存储仍将保存,但通常写入其字节的空间将是空闲空间。随后可以在单独的过程中填充存储字节。
对于加载路径,张量将按检查点加载,但它们的存储不会填充数据。
警告
skip_data
上下文管理器是早期原型,可能会发生变化。- 参数
materialize_fake_tensors (bool) – 是否在保存期间具体化 FakeTensors。这对于加载路径是空操作。
示例
>>> import tempfile >>> t = torch.randn(2, 3) >>> with tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], [0., 0., 0.]])
配置¶
torch.utils.serialization.config
提供了一个全局配置,可以控制 torch.save
和 torch.load
的行为。
torch.utils.serialization.config.save
包含控制 torch.save
行为的选项。
compute_crc32
:是否计算并写入 zip 文件校验和(默认值:True
)。参见set_crc32_options()
。
use_pinned_memory_for_d2h
:对于传递给torch.save
时位于加速器上的存储,是否在torch.save
内将存储移动到 CPU 上的锁定内存或可分页内存(默认值:False
(即可分页))。
storage_alignment
:在torch.save
期间检查点中存储的对齐方式(以字节为单位)。(默认值64
)
torch.utils.serialization.config.load
包含控制 torch.load
行为的选项。
mmap
:参见torch.load()
中mmap
参数的文档。如果未显式传递给torch.load
调用,此配置将设置mmap
对于torch.load
的行为(默认值:False
)。
endianness
:参见set_default_load_endianness()
。(默认值:torch.serialization.LoadEndianness.NATIVE
)
mmap_flags
:参见set_default_mmap_options
。(默认值:MAP_PRIVATE
)
calculate_storage_offsets
:如果此配置设置为True
,则在使用torch.load(mmap=True)
时将计算存储的偏移量,而不是通过随机读取来读取。这最大程度地减少了随机读取,当文件通过网络加载时会很有帮助。(默认值:False
)