序列化语义¶
本说明描述了如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块以便可以在 C++ 中加载它们。
目录
保存和加载张量¶
torch.save()
和 torch.load()
让您可以轻松保存和加载张量
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照惯例,PyTorch 文件通常使用“.pt”或“.pth”扩展名编写。
torch.save()
和 torch.load()
默认使用 Python 的 pickle,因此您还可以将多个张量保存为 Python 对象(如元组、列表和字典)的一部分
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
如果包含 PyTorch 张量的自定义数据结构是可 pickle 的,则也可以保存。
保存和加载张量会保留视图¶
保存张量会保留其视图关系
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在幕后,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅 张量视图。
当 PyTorch 保存张量时,它会分别保存其存储对象和张量元数据。这是一个实现细节,将来可能会更改,但这通常可以节省空间,并让 PyTorch 轻松重建加载的张量之间的视图关系。例如,在上面的代码片段中,只有一个存储被写入“tensors.pt”。
但是,在某些情况下,保存当前存储对象可能是不必要的,并且会创建过大的文件。在以下代码片段中,一个比保存的张量大得多的存储被写入文件
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
不是仅将 small 张量中的五个值保存到“small.pt”中,而是保存并加载了与其与 large 共享的存储中的 999 个值。
当保存元素少于其存储对象的张量时,可以通过首先克隆张量来减小保存文件的大小。克隆张量会生成一个新张量,该张量具有仅包含张量中值的新存储对象
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
但是,由于克隆的张量彼此独立,因此它们不具有原始张量所具有的任何视图关系。如果文件大小和视图关系在保存小于其存储对象的张量时都很重要,则必须注意构造新的张量,以最大限度地减小其存储对象的大小,但仍具有所需的视图关系,然后再保存。
保存和加载 torch.nn.Modules¶
另请参阅:教程:保存和加载模块
在 PyTorch 中,模块的状态通常使用“状态字典”进行序列化。模块的状态字典包含其所有参数和持久缓冲区
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
建议不要直接保存模块,而是仅保存其状态字典,以实现兼容性。Python 模块甚至有一个函数 load_state_dict()
,用于从状态字典恢复其状态
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
请注意,状态字典首先使用 torch.load()
从其文件中加载,然后使用 load_state_dict()
恢复状态。
即使是自定义模块和包含其他模块的模块也具有状态字典,并且可以使用此模式
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
torch.save
的序列化文件格式¶
自 PyTorch 1.6.0 起,除非用户设置 _use_new_zipfile_serialization=False
,否则 torch.save
默认返回未压缩的 ZIP64 存档。
在此存档中,文件按以下顺序排列
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
- 条目如下
data.pkl
是对传递给torch.save
的对象(不包括它包含的torch.Storage
对象)进行 pickle 化的结果byteorder
包含一个字符串,其中包含保存时的sys.byteorder
(“little”或“big”)data/
包含对象中的所有存储,其中每个存储都是一个单独的文件version
包含保存时的版本号,该版本号可以在加载时使用
保存时,PyTorch 将确保每个文件的本地文件头都填充到偏移量,该偏移量是 64 字节的倍数,从而确保每个文件的偏移量都与 64 字节对齐。
注意
某些设备(如 XLA)上的张量被序列化为 pickle 化的 numpy 数组。因此,它们的存储不会被序列化。在这些情况下,data/
可能在检查点中不存在。
使用 weights_only=True
的 torch.load
¶
从版本 2.6 开始,如果未传递 pickle_module
参数,torch.load
将使用 weights_only=True
。
如 torch.load()
的文档中所述,weights_only=True
将 torch.load
中使用的 unpickler 限制为仅执行 state_dicts
的普通 torch.Tensors
以及其他一些原始类型所需的功能/构建类。此外,与 pickle
模块提供的默认 Unpickler
不同,weights_only
Unpickler 不允许在 unpickling 期间动态导入任何内容。
如上所述,当使用 torch.save
时,保存模块的 state_dict
是一种最佳实践。如果加载包含 nn.Module
的旧检查点,我们建议使用 weights_only=False
。当加载包含张量子类的检查点时,可能需要将函数/类列入允许列表,有关更多详细信息,请参见下文。
如果 weights_only
Unpickler 遇到 pickle 文件中默认未列入允许列表的函数或类,您应该会看到类似于以下内容的可操作错误
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
2. Alternatively, to load with `weights_only=True` please check the recommended
steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
`torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
if you trust this class/function.
请按照错误消息中的步骤操作,并且仅在您信任这些函数或类的情况下才将其列入允许列表。
要获取检查点中尚未列入允许列表的所有 GLOBAL(函数/类),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint()
,它将返回 {__module__}.{__name__}
形式的字符串列表。如果您信任这些函数/类,则可以导入它们,并根据错误消息通过 torch.serialization.add_safe_globals()
或上下文管理器 torch.serialization.safe_globals
将它们列入允许列表。
要访问用户允许列表中的函数/类列表,您可以使用 torch.serialization.get_safe_globals()
,要清除当前列表,请参阅 torch.serialization.clear_safe_globals()
。
故障排除 weights_only
¶
获取不安全全局变量¶
一个需要注意的地方是,torch.serialization.get_unsafe_globals_in_checkpoint()
静态分析检查点,某些类型可能会在 unpickling 过程中动态构建,因此 torch.serialization.get_unsafe_globals_in_checkpoint()
不会报告这些类型。numpy 中的 dtypes
就是这样一个示例。在 numpy < 1.25
中,在将 torch.serialization.get_unsafe_globals_in_checkpoint()
报告的所有函数/类列入允许列表后,您可能会看到如下错误
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>
可以通过 {add_}safe_globals([type(np.dtype(np.float32))])
将其列入允许列表。
在 numpy >=1.25
中,您会看到
WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>
可以通过 {add_}safe_globals([np.dtypes.Float32DType])
将其列入允许列表。
序列化 torch.nn.Modules 并在 C++ 中加载它们¶
另请参阅:教程:在 C++ 中加载 TorchScript 模型
ScriptModules 可以序列化为 TorchScript 程序,并使用 torch.jit.load()
加载。此序列化编码所有模块的方法、子模块、参数和属性,并允许在 C++ 中加载序列化程序(即,无需 Python)。
torch.jit.save()
和 torch.save()
之间的区别可能不是很明显。torch.save()
使用 pickle 保存 Python 对象。这对于原型设计、研究和训练特别有用。torch.jit.save()
另一方面,将 ScriptModules 序列化为可以在 Python 或 C++ 中加载的格式。当保存和加载 C++ 模块或使用 C++ 运行在 Python 中训练的模块时,这非常有用,这是部署 PyTorch 模型时的常见做法。
要在 Python 中编写脚本、序列化和加载模块
>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
(l0): RecursiveScriptModule(original_name=Linear)
(l1): RecursiveScriptModule(original_name=Linear) )
跟踪模块也可以使用 torch.jit.save()
保存,但需要注意的是,仅序列化跟踪的代码路径。以下示例演示了这一点
# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
if input.dim() > 1:
return torch.tensor(0)
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)
>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)
上面的模块有一个 if 语句,该语句未被跟踪输入触发,因此不属于跟踪模块,也不会随其一起序列化。但是,脚本模块包含 if 语句并随其一起序列化。有关脚本和跟踪的更多信息,请参阅 TorchScript 文档。
最后,要在 C++ 中加载模块
>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');
有关如何在 C++ 中使用 PyTorch 模块的详细信息,请参阅 PyTorch C++ API 文档。
跨 PyTorch 版本保存和加载 ScriptModules¶
PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。旧版本的 PyTorch 可能不支持较新的模块,而较新的版本可能已删除或修改了较旧的行为。这些更改在 PyTorch 的 发行说明 中明确描述,并且依赖于已更改的功能的模块可能需要更新才能继续正常工作。在以下详述的有限情况下,PyTorch 将保留序列化 ScriptModules 的历史行为,以便它们不需要更新。
torch.div 执行整数除法¶
在 PyTorch 1.5 及更早版本中,当给定两个整数输入时,torch.div()
将执行向下取整除法
# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)
但是在 PyTorch 1.7 中,torch.div()
将始终对其输入执行真除法,就像 Python 3 中的除法一样
# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)
torch.div()
的行为在序列化的 ScriptModules 中得到保留。也就是说,使用 1.6 之前的 PyTorch 版本序列化的 ScriptModules 将继续看到当给定两个整数输入时,torch.div()
执行向下取整除法,即使使用较新版本的 PyTorch 加载也是如此。但是,在 PyTorch 1.6 及更高版本上使用 torch.div()
并序列化的 ScriptModules 无法在早期版本的 PyTorch 中加载,因为这些早期版本不理解新行为。
torch.full 始终推断浮点 dtype¶
在 PyTorch 1.5 及更早版本中,torch.full()
始终返回浮点张量,而不管给定的填充值是什么
# PyTorch 1.5 and earlier
>>> torch.full((3,), 1) # Note the integer fill value...
tensor([1., 1., 1.]) # ...but float tensor!
但是在 PyTorch 1.7 中,torch.full()
将从填充值推断返回张量的 dtype
# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])
>>> torch.full((3,), True)
tensor([True, True, True])
>>> torch.full((3,), 1.)
tensor([1., 1., 1.])
>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])
torch.full()
的行为在序列化的 ScriptModules 中得到保留。也就是说,使用 1.6 之前的 PyTorch 版本序列化的 ScriptModules 将继续看到 torch.full 默认返回浮点张量,即使给定布尔值或整数填充值也是如此。但是,在 PyTorch 1.6 及更高版本上使用 torch.full()
并序列化的 ScriptModules 无法在早期版本的 PyTorch 中加载,因为这些早期版本不理解新行为。
实用函数¶
以下实用函数与序列化相关
- torch.serialization.register_package(priority, tagger, deserializer)[源代码][源代码]¶
注册可调用对象,用于标记和反序列化具有关联优先级的存储对象。标记在保存时将设备与存储对象关联,而反序列化在加载时将存储对象移动到合适的设备。
tagger
和deserializer
按照其priority
给定的顺序运行,直到标记器/反序列化器返回的值不是 None。要覆盖全局注册表中设备的反序列化行为,可以注册一个优先级高于现有标记器的标记器。
此函数还可以用于为新设备注册标记器和反序列化器。
- 参数
priority (int) – 指示与标记器和反序列化器关联的优先级,其中值越低表示优先级越高。
tagger (Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 接收存储对象并返回其标记设备作为字符串或 None 的可调用对象。
deserializer (Callable[[Union[Storage, TypedStorage, UntypedStorage], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 接收存储对象和设备字符串,并返回适当设备上的存储对象或 None 的可调用对象。
- 返回
None
示例
>>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- torch.serialization.get_crc32_options()[source][source]¶
获取
torch.save()
是否为每个记录计算并写入 crc32 校验码。默认为
True
。- 返回类型
- torch.serialization.set_crc32_options(compute_crc32)[source][source]¶
设置
torch.save()
是否为每个记录计算并写入 crc32 校验码。注意
将其设置为
False
可能会导致解压缩torch.save
输出失败或因 CRC32 损坏而发出警告。但是torch.load
将能够加载该文件。- 参数
compute_crc32 (bool) – 设置 crc32 计算标志
- torch.serialization.get_default_load_endianness()[source][source]¶
获取加载文件的回退字节顺序
如果在保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“native”字节顺序。
- 返回
Optional[LoadEndianness]
- 返回类型
default_load_endian
- torch.serialization.set_default_load_endianness(endianness)[source][source]¶
设置加载文件的回退字节顺序
如果在保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“native”字节顺序。
- 参数
endianness – 新的回退字节顺序
- torch.serialization.get_default_mmap_options()[source][source]¶
获取
torch.load()
与mmap=True
一起使用的默认 mmap 选项。默认为
mmap.MAP_PRIVATE
。- 返回
int
- 返回类型
default_mmap_options
- torch.serialization.set_default_mmap_options(flags)[source][source]¶
上下文管理器或函数,用于将
torch.load()
与mmap=True
一起使用的默认 mmap 选项设置为 flags。目前,仅支持
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
。如果您需要在此处添加任何其他选项,请打开一个 issue。注意
Windows 尚不支持此功能。
- 参数
flags (int) –
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
- torch.serialization.add_safe_globals(safe_globals)[source][source]¶
将给定的全局变量标记为
weights_only
加载的安全项。例如,添加到此列表的函数可以在 unpickling 期间调用,类可以被实例化并设置状态。列表中的每个项目可以是函数/类,也可以是(函数/类,字符串)形式的元组,其中字符串是函数/类的完整路径。
在序列化格式中,每个函数都使用其完整路径
{__module__}.{__name__}
进行标识。调用此 API 时,您可以提供此完整路径,该路径应与检查点中的路径匹配,否则将使用默认的{fn.__module__}.{fn.__name__}
。- 参数
safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要标记为安全的全局变量列表
示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])
- torch.serialization.get_unsafe_globals_in_checkpoint(f)[source][source]¶
返回
torch.save
对象中对于weights_only
不安全的函数/类的字符串列表。对于给定的函数或类
f
,相应的字符串将采用{f.__module__}.{f.__name__}
的形式。此函数将返回检查点中未标记为对于
weights_only
安全的任何 GLOBAL(通过add_safe_globals()
或safe_globals
上下文或默认情况下由torch
允许列表)。注意
此函数将静态反汇编检查点中的 pickle 文件。这意味着在 unpickling 期间动态推送到堆栈的任何类都不会包含在输出中。
- class torch.serialization.safe_globals(safe_globals)[source][source]¶
上下文管理器,用于将某些全局变量添加为对于
weights_only
加载的安全项。示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == []
- class torch.serialization.skip_data(materialize_fake_tensors=False)[source][source]¶
上下文管理器,用于跳过
torch.save
调用的存储字节写入。存储仍然会被保存,但是通常写入其字节的空间将为空白空间。然后可以在单独的通道中填充存储字节。
警告
skip_data
上下文管理器是一个早期原型,可能会发生更改。- 参数
materialize_fake_tensors (bool) – 是否实现 FakeTensor。
示例
>>> import tempfile >>> t = torch.randn(2, 3) >>> with tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], [0., 0., 0.]])