序列化语义¶
本笔记介绍如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块以便在 C++ 中加载。
目录
保存和加载张量¶
torch.save()
和 torch.load()
允许您轻松保存和加载张量
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照惯例,PyTorch 文件通常使用“.pt”或“.pth”扩展名。
torch.save()
和 torch.load()
默认使用 Python 的 pickle,因此您还可以将多个张量作为 Python 对象(如元组、列表和字典)的一部分保存。
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
如果数据结构是可 pickle 的,则包含 PyTorch 张量的自定义数据结构也可以保存。
保存和加载张量会保留视图¶
保存张量会保留它们的视图关系。
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
在幕后,这些张量共享相同的“存储”。有关视图和存储的更多信息,请参阅 张量视图。
当 PyTorch 保存张量时,它会分别保存它们的存储对象和张量元数据。这是一个可能在将来更改的实现细节,但它通常可以节省空间并让 PyTorch 轻松重建加载的张量之间的视图关系。例如,在上面的代码段中,只有一个存储被写入“tensors.pt”。
但是,在某些情况下,保存当前存储对象可能是不必要的,并且会创建过大的文件。在下面的代码段中,一个比保存的张量大得多的存储被写入文件。
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
而不是仅将 small 张量中的五个值保存到“small.pt”,它还保存并加载了它与 large 共享的存储中的 999 个值。
当保存的张量元素少于其存储对象时,可以通过先克隆张量来减小保存的文件大小。克隆张量会生成一个新的张量,该张量具有一个新的存储对象,其中仅包含张量中的值。
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt') # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
但是,由于克隆的张量彼此独立,因此它们没有原始张量具有的任何视图关系。如果在保存小于其存储对象的张量时文件大小和视图关系都很重要,则必须注意在保存之前构造新的张量,以最大程度地减小其存储对象的大小,但仍具有所需的视图关系。
保存和加载 torch.nn.Modules¶
另请参阅:教程:保存和加载模块
在 PyTorch 中,模块的状态通常使用“状态字典”进行序列化。模块的状态字典包含其所有参数和持久缓冲区。
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
('bias', tensor([0., 0., 0.])),
('running_mean', tensor([0., 0., 0.])),
('running_var', tensor([1., 1., 1.])),
('num_batches_tracked', tensor(0))])
为了提高兼容性,建议不要直接保存模块,而是仅保存其状态字典。Python 模块甚至有一个函数,load_state_dict()
,用于从状态字典中恢复其状态。
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
请注意,状态字典首先使用 torch.load()
从其文件中加载,然后使用 load_state_dict()
恢复状态。
即使是自定义模块和包含其他模块的模块也具有状态字典,并且可以使用此模式。
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
[-0.3289, 0.2827, 0.4588, 0.2031]])),
('l0.bias', tensor([ 0.0300, -0.1316])),
('l1.weight', tensor([[0.6533, 0.3413]])),
('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
torch.save
的序列化文件格式¶
从 PyTorch 1.6.0 开始,torch.save
默认返回未压缩的 ZIP64 存档,除非用户设置 _use_new_zipfile_serialization=False
。
在此存档中,文件按以下顺序排列。
checkpoint.pth
├── data.pkl
├── byteorder # added in PyTorch 2.1.0
├── data/
│ ├── 0
│ ├── 1
│ ├── 2
│ └── …
└── version
- 条目如下所示。
data.pkl
是对传递给torch.save
的对象的 pickle 结果,不包括它包含的torch.Storage
对象。byteorder
包含一个字符串,其中包含保存时的sys.byteorder
(“little”或“big”)。data/
包含对象中的所有存储,其中每个存储都是一个单独的文件。version
包含保存时的版本号,可以在加载时使用。
保存时,PyTorch 将确保每个文件的本地文件头填充到 64 字节倍数的偏移量,确保每个文件的偏移量都是 64 字节对齐的。
注意
某些设备(例如 XLA)上的张量被序列化为 pickle 后的 NumPy 数组。因此,它们的存储不会被序列化。在这些情况下,data/
可能不存在于检查点中。
序列化 torch.nn.Modules 并在 C++ 中加载它们¶
另请参阅:教程:在 C++ 中加载 TorchScript 模型
ScriptModules 可以序列化为 TorchScript 程序,并使用 torch.jit.load()
加载。此序列化编码了所有模块的方法、子模块、参数和属性,并允许在 C++ 中加载序列化的程序(即无需 Python)。
torch.jit.save()
和 torch.save()
之间的区别可能并不立即清楚。torch.save()
使用 pickle 保存 Python 对象。这对于原型设计、研究和训练特别有用。torch.jit.save()
另一方面,将 ScriptModules 序列化为可以在 Python 或 C++ 中加载的格式。这在保存和加载 C++ 模块或使用 C++ 运行在 Python 中训练的模块时很有用,这是部署 PyTorch 模型的常见做法。
在 Python 中脚本化、序列化和加载模块
>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
(l0): RecursiveScriptModule(original_name=Linear)
(l1): RecursiveScriptModule(original_name=Linear) )
跟踪的模块也可以使用 torch.jit.save()
保存,但需要注意的是,只有跟踪的代码路径会被序列化。以下示例演示了这一点
# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l0 = torch.nn.Linear(4, 2)
self.l1 = torch.nn.Linear(2, 1)
def forward(self, input):
if input.dim() > 1:
return torch.tensor(0)
out0 = self.l0(input)
out0_relu = torch.nn.functional.relu(out0)
return self.l1(out0_relu)
>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)
>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)
上面的模块有一个 if 语句,它不会被跟踪的输入触发,因此它不是跟踪模块的一部分,也不会与它一起序列化。但是,脚本化的模块包含 if 语句,并与它一起序列化。有关脚本化和跟踪的更多信息,请参阅 TorchScript 文档。
最后,在 C++ 中加载模块
>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');
有关如何在 C++ 中使用 PyTorch 模块的详细信息,请参阅 PyTorch C++ API 文档。
跨 PyTorch 版本保存和加载 ScriptModules¶
PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。旧版本的 PyTorch 可能不支持较新的模块,而较新版本可能已删除或修改了旧的行为。这些更改在 PyTorch 的 发行说明 中有明确说明,依赖于已更改功能的模块可能需要更新才能继续正常工作。在下面详细介绍的有限情况下,PyTorch 将保留序列化 ScriptModules 的历史行为,因此它们不需要更新。
torch.div 执行整数除法¶
在 PyTorch 1.5 及更早版本中,torch.div()
在给定两个整数输入时将执行地板除法。
# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)
但是,在 PyTorch 1.7 中,torch.div()
将始终执行其输入的真除法,就像 Python 3 中的除法一样。
# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)
torch.div()
的行为在序列化 ScriptModules 中得以保留。也就是说,使用 PyTorch 1.6 之前的版本序列化的 ScriptModules 将继续看到 torch.div()
在给定两个整数输入时执行地板除法,即使使用较新版本的 PyTorch 加载也是如此。但是,使用 torch.div()
并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在 PyTorch 的早期版本中加载,因为这些早期版本不理解新的行为。
torch.full 始终推断浮点类型¶
在 PyTorch 1.5 及更早版本中,torch.full()
始终返回浮点张量,而不管给定的填充值是什么。
# PyTorch 1.5 and earlier
>>> torch.full((3,), 1) # Note the integer fill value...
tensor([1., 1., 1.]) # ...but float tensor!
但是,在 PyTorch 1.7 中,torch.full()
将从填充值推断返回张量的类型。
# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])
>>> torch.full((3,), True)
tensor([True, True, True])
>>> torch.full((3,), 1.)
tensor([1., 1., 1.])
>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])
torch.full()
的行为在序列化 ScriptModules 中得以保留。也就是说,使用 PyTorch 1.6 之前的版本序列化的 ScriptModules 将继续看到 torch.full 默认返回浮点张量,即使给定布尔或整数填充值也是如此。但是,使用 torch.full()
并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在 PyTorch 的早期版本中加载,因为这些早期版本不理解新的行为。
实用函数¶
以下实用函数与序列化相关
- torch.serialization.register_package(priority, tagger, deserializer)[source]¶
注册用于标记和反序列化存储对象的回调函数,并具有关联的优先级。标记在保存时将设备与存储对象关联,而反序列化在加载时将存储对象移动到相应的设备。
tagger
和deserializer
按照其priority
给出的顺序运行,直到标记器/反序列化器返回的值不为 None。要覆盖全局注册表中设备的反序列化行为,可以注册一个优先级高于现有标记器的标记器。
此函数也可用于为新设备注册标记器和反序列化器。
- 参数
priority (int) – 指示与标记器和反序列化器关联的优先级,其中较低的值表示较高优先级。
tagger (Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 可调用对象,接收存储对象并将其标记的设备作为字符串或 None 返回。
deserializer (Callable[[Union[Storage, TypedStorage, UntypedStorage], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 可调用对象,接收存储对象和设备字符串,并返回相应设备上的存储对象或 None。
- 返回
无
示例
>>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- torch.serialization.get_default_load_endianness()[source]¶
获取加载文件的回退字节序
如果保存的检查点中不存在字节序标记,则此字节序用作回退。默认情况下,它是“本机”字节序。
- 返回
Optional[LoadEndianness]
- 返回类型
default_load_endian
- torch.serialization.set_default_load_endianness(endianness)[source]¶
设置加载文件的回退字节序
如果保存的检查点中不存在字节序标记,则此字节序用作回退。默认情况下,它是“本机”字节序。
- 参数
endianness – 新的回退字节序
- torch.serialization.get_default_mmap_options()[source]¶
获取
torch.load()
使用mmap=True
时的默认 mmap 选项。默认为
mmap.MAP_PRIVATE
。- 返回
int
- 返回类型
default_mmap_options
- torch.serialization.set_default_mmap_options(flags)[source]¶
上下文管理器或函数,用于将
torch.load()
使用mmap=True
时的默认 mmap 选项设置为 flags。目前,仅支持
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
。如果您需要添加其他选项,请提交 issue。注意
此功能目前不支持 Windows。
- 参数
flags (int) –
mmap.MAP_PRIVATE
或mmap.MAP_SHARED
- torch.serialization.add_safe_globals(safe_globals)[source]¶
将给定的全局变量标记为
weights_only
加载的安全变量。例如,添加到此列表中的函数可以在解封过程中被调用,类可以被实例化并设置状态。- 参数
safe_globals (List[Any]) – 要标记为安全的全局变量列表
示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])
- class torch.serialization.safe_globals(safe_globals)[source]¶
上下文管理器,将某些全局变量标记为
weights_only
加载的安全变量。示例
>>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... with torch.serialization.safe_globals([MyTensor]): ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]]) >>> assert torch.serialization.get_safe_globals() == []
- class torch.serialization.skip_data(materialize_fake_tensors=False)[source]¶
上下文管理器,跳过
torch.save
调用时的存储字节写入。存储仍然会被保存,但通常存储其字节的空间将为空白空间。存储字节可以在单独的传递中填充。
警告
skip_data
上下文管理器是一个早期原型,可能会发生变化。- 参数
materialize_fake_tensors (bool) – 是否实现 FakeTensors。
示例
>>> import tempfile >>> t = torch.randn(2, 3) >>> with tempfile.NamedTemporaryFile() as f: ... with torch.serialization.skip_data(): ... torch.save(t, f.name) ... torch.load(f.name, weights_only=True) tensor([[0., 0., 0.], [0., 0., 0.]])