快捷方式

序列化语义

本说明介绍了如何在 Python 中保存和加载 PyTorch 张量和模块状态,以及如何序列化 Python 模块以便在 C++ 中加载它们。

保存和加载张量

torch.save()torch.load() 使您能够轻松保存和加载张量

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照惯例,PyTorch 文件通常使用 '.pt' 或 '.pth' 扩展名。

torch.save()torch.load() 默认使用 Python 的 pickle,因此您也可以将多个张量保存为元组、列表和字典等 Python 对象的一部分

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

包含 PyTorch 张量的自定义数据结构,如果该数据结构是可 pickle 化的,也可以被保存。

保存和加载张量保留视图

保存张量会保留它们的视图关系

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕后,这些张量共享相同的“存储 (storage)”。请参阅 张量视图 (Tensor Views) 以了解有关视图和存储的更多信息。

当 PyTorch 保存张量时,它会单独保存它们的存储对象 (storage objects) 和张量元数据。这是一个实现细节,将来可能会发生变化,但它通常可以节省空间,并让 PyTorch 轻松重建加载的张量之间的视图关系。例如,在上面的代码片段中,只有一个存储被写入 'tensors.pt'。

然而,在某些情况下,保存当前的存储对象可能是没有必要的,并且会创建过大的文件。在下面的代码片段中,一个远大于所保存张量的存储被写入到一个文件

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

保存到 'small.pt' 文件中的,不是 small 张量中的五个值,而是它与 large 共享的存储中的 999 个值被保存和加载了。

当保存的张量包含的元素少于其存储对象中的元素时,可以通过先复制 (cloning) 张量来减小保存文件的大小。复制张量会生成一个新的张量,该张量拥有一个新的存储对象,仅包含该张量中的值

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

然而,由于复制的张量是相互独立的,它们不具有原始张量之间的视图关系。如果保存小于其存储对象的张量时,文件大小和视图关系都很重要,那么必须在保存之前仔细构建新的张量,以尽量减小其存储对象的大小,但仍保留所需的视图关系。

保存和加载 torch.nn.Modules

另请参阅:教程:保存和加载模块

在 PyTorch 中,模块的状态通常使用“状态字典 (state dict)”进行序列化。模块的状态字典包含其所有参数和持久缓冲区 (persistent buffers)

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

出于兼容性原因,建议不要直接保存模块,而是只保存其状态字典。Python 模块甚至有一个函数 load_state_dict(),用于从状态字典恢复其状态

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

请注意,状态字典首先使用 torch.load() 从文件中加载,然后使用 load_state_dict() 恢复状态。

即使是自定义模块和包含其他模块的模块也具有状态字典,并且可以使用此模式

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

torch.save 的序列化文件格式

自 PyTorch 1.6.0 版本起,除非用户设置 _use_new_zipfile_serialization=False,否则 torch.save 默认返回未压缩的 ZIP64 归档文件。

在此归档文件中,文件按如下顺序排列

checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
条目如下
  • data.pkl 是对传递给 torch.save 的对象进行 pickle 化的结果,其中不包含对象内部的 torch.Storage 对象

  • byteorder 包含一个字符串,表示保存时的 sys.byteorder(“little” 或 “big”)

  • data/ 包含对象中的所有存储,其中每个存储都是一个单独的文件

  • version 包含保存时的版本号,可在加载时使用

保存时,PyTorch 将确保每个文件的本地文件头填充到 64 字节的倍数偏移量,从而确保每个文件的偏移量是 64 字节对齐的。

注意

某些设备(如 XLA)上的张量被序列化为 pickled 的 numpy 数组。因此,它们的存储不会被序列化。在这种情况下,检查点中可能不存在 data/

weights_only=Truetorch.load

从 2.6 版本开始,如果未传递 pickle_module 参数,torch.load 将使用 weights_only=True

正如 torch.load() 文档中所讨论的,weights_only=Truetorch.load 中使用的 unpickler 限制为仅执行普通 torch.Tensorsstate_dicts 以及其他一些原始类型所需的函数/构建类。此外,与 pickle 模块提供的默认 Unpickler 不同,weights_only Unpickler 在 unpickling 过程中不允许动态导入任何内容。

如上所述,在使用 torch.save 时,保存模块的 state_dict 是一个最佳实践。如果加载包含 nn.Module 的旧检查点,我们建议使用 weights_only=False。加载包含张量子类的检查点时,很可能需要将某些函数/类添加到允许列表中,详情见下文。

如果 weights_only Unpickler 在 pickle 文件中遇到默认情况下未添加到允许列表的函数或类,您应该会看到类似以下的可操作错误

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
    1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
        but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    2. Alternatively, to load with `weights_only=True` please check the recommended
       steps in the following error message.
       WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
       default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
       `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
       if you trust this class/function.

请按照错误消息中的步骤,仅在您信任这些函数或类时,才将它们添加到允许列表中。

要获取检查点中所有尚未添加到允许列表的 GLOBAL(函数/类),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint(),它将返回一个形式为 {__module__}.{__name__} 的字符串列表。如果您信任这些函数/类,您可以导入它们,并按照错误消息的指示,通过 torch.serialization.add_safe_globals() 或上下文管理器 torch.serialization.safe_globals 将它们添加到允许列表中。

要访问用户添加到允许列表的函数/类列表,您可以使用 torch.serialization.get_safe_globals();要清除当前列表,请参阅 torch.serialization.clear_safe_globals()

排除 weights_only 故障

获取不安全的全局变量

需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint() 是静态分析检查点,某些类型可能在 unpickling 过程中动态构建,因此不会被 torch.serialization.get_unsafe_globals_in_checkpoint() 报告。一个这样的例子是 numpy 中的 dtypes。在 numpy < 1.25 中,将 torch.serialization.get_unsafe_globals_in_checkpoint() 报告的所有函数/类添加到允许列表后,您可能会看到类似以下的错误

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>

这可以通过 {add_}safe_globals([type(np.dtype(np.float32))]) 添加到允许列表。

numpy >=1.25 中,您会看到

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>

这可以通过 {add_}safe_globals([np.dtypes.Float32DType]) 添加到允许列表。

环境变量

有两个环境变量会影响 torch.load 的行为。如果您无法访问 torch.load 调用点,这些变量会很有帮助。

  • TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 将覆盖所有 torch.load 调用点,使其使用 weights_only=True

  • TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 将使 torch.load 调用点仅在未将 weights_only 作为参数传递时才使用 weights_only=False

序列化 torch.nn.Modules 并在 C++ 中加载它们

另请参阅:教程:在 C++ 中加载 TorchScript 模型

ScriptModules 可以序列化为 TorchScript 程序,并使用 torch.jit.load() 加载。这种序列化编码了模块的所有方法、子模块、参数和属性,并且允许序列化的程序在 C++ 中加载(即无需 Python 环境)。

torch.jit.save()torch.save() 之间的区别可能不是立即可见的。torch.save() 使用 pickle 保存 Python 对象。这对于原型开发、研究和训练特别有用。torch.jit.save() 则将 ScriptModules 序列化为可以在 Python 或 C++ 中加载的格式。这在保存和加载 C++ 模块或使用 C++ 运行在 Python 中训练的模块时非常有用,这是部署 PyTorch 模型时的常见做法。

在 Python 中进行脚本化、序列化和加载模块

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

跟踪模块也可以使用 torch.jit.save() 保存,但需要注意的是,只序列化跟踪到的代码路径。以下示例演示了这一点

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上面的模块有一个 if 语句,该语句未被跟踪的输入触发,因此它不是跟踪模块的一部分,也不会随之序列化。然而,脚本化的模块包含该 if 语句并随之序列化。有关脚本化和跟踪的更多信息,请参阅 TorchScript 文档

最后,在 C++ 中加载模块

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

有关如何在 C++ 中使用 PyTorch 模块的详细信息,请参阅 PyTorch C++ API 文档

跨 PyTorch 版本保存和加载 ScriptModules

PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。较旧的 PyTorch 版本可能不支持较新的模块,而较新的版本可能已删除或修改了旧的行为。这些更改在 PyTorch 的 发布说明 中有明确描述,依赖已更改功能的模块可能需要更新才能继续正常工作。在下面详细介绍的有限情况下,PyTorch 将保留序列化 ScriptModules 的历史行为,这样它们就不需要更新。

torch.div 执行整数除法

在 PyTorch 1.5 及更早版本中,当给定两个整数输入时,torch.div() 会执行地板除法 (floor division)

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

然而,在 PyTorch 1.7 中,torch.div() 将始终对其输入执行真除法 (true division),就像 Python 3 中的除法一样

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

torch.div() 的行为在序列化的 ScriptModules 中被保留。也就是说,即使使用较新版本的 PyTorch 加载,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 在给定两个整数输入时仍将看到 torch.div() 执行地板除法。然而,在 PyTorch 1.6 及更高版本上使用 torch.div() 并序列化的 ScriptModules 不能在更早的 PyTorch 版本中加载,因为这些更早的版本无法理解新行为。

torch.full 总是推断为浮点型 dtype

在 PyTorch 1.5 及更早版本中,无论给定什么填充值,torch.full() 总是返回一个浮点型张量

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

然而,在 PyTorch 1.7 中,torch.full() 将从填充值推断返回张量的 dtype

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full() 的行为在序列化的 ScriptModules 中被保留。也就是说,即使给定 bool 或整数填充值,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 默认仍将看到 torch.full 返回浮点型张量。然而,在 PyTorch 1.6 及更高版本上使用 torch.full() 并序列化的 ScriptModules 不能在更早的 PyTorch 版本中加载,因为这些更早的版本无法理解新行为。

实用函数

以下实用函数与序列化相关

torch.serialization.register_package(priority, tagger, deserializer)[源][源]

注册用于标记和反序列化存储对象的可调用对象,并关联优先级。标记 (tagging) 在保存时将设备与存储对象关联,而反序列化 (deserializing) 在加载时将存储对象移动到适当的设备。taggerdeserializer 按照其 priority 给定的顺序运行,直到 tagger/deserializer 返回一个非 None 的值。

要覆盖全局注册表中某个设备的反序列化行为,可以注册一个优先级高于现有 tagger 的 tagger。

此函数还可用于为新设备注册 tagger 和 deserializer。

参数
返回

None

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_crc32_options()[source][source]

获取 torch.save() 是否计算并为每个记录写入 crc32。

默认为 True

返回类型

bool

torch.serialization.set_crc32_options(compute_crc32)[source][source]

设置 torch.save() 是否计算并为每个记录写入 crc32。

注意

将其设置为 False 可能会导致解压 torch.save 输出时因 CRC32 损坏而失败或发出警告。但 torch.load 将能够加载该文件。

参数

compute_crc32 (bool) – 设置 crc32 计算标志

torch.serialization.get_default_load_endianness()[source][source]

获取加载文件的回退字节顺序

如果保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“native”字节顺序。

返回

Optional[LoadEndianness]

返回类型

default_load_endian

torch.serialization.set_default_load_endianness(endianness)[source][source]

设置加载文件的回退字节顺序

如果保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“native”字节顺序。

参数

endianness – 新的回退字节顺序

torch.serialization.get_default_mmap_options()[source][source]

获取 torch.load() 并设置 mmap=True 时的默认 mmap 选项。

默认为 mmap.MAP_PRIVATE

返回

int

返回类型

default_mmap_options

torch.serialization.set_default_mmap_options(flags)[source][source]

上下文管理器或函数,用于设置 torch.load() 并设置 mmap=True 时的默认 mmap 选项为 flags。

目前,仅支持 mmap.MAP_PRIVATEmmap.MAP_SHARED。如果您需要添加任何其他选项,请提交 issue。

注意

此功能目前不支持 Windows。

参数

flags (int) – mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[source][source]

将给定的全局变量标记为对 weights_only 加载安全。例如,添加到此列表中的函数可以在反序列化时被调用,类可以被实例化并设置状态。

列表中的每个项可以是函数/类本身,也可以是形式为 (函数/类, 字符串) 的元组,其中字符串是函数/类的完整路径。

在序列化格式中,每个函数都由其完整路径 {__module__}.{__qualname__} 标识。调用此 API 时,您可以提供应该与检查点中匹配的完整路径,否则将使用默认的 {fn.__module__}.{fn.__qualname__}

参数

safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要标记为安全的全局变量列表

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[source][source]

清除对 weights_only 加载安全的全局变量列表。

torch.serialization.get_safe_globals()[source][source]

返回用户添加的对 weights_only 加载安全的全局变量列表。

返回类型

list[Union[Callable, tuple[Callable, str]]]

torch.serialization.get_unsafe_globals_in_checkpoint(f)[source][source]

返回一个字符串列表,包含 torch.save 对象中对 weights_only 不安全的函数/类。

对于给定的函数或类 f,对应的字符串将是 {f.__module__}.{f.__name__} 的形式。

此函数将返回检查点中未标记为对 weights_only 安全的任何全局变量(无论是通过 add_safe_globals()safe_globals 上下文,还是由 torch 默认列入白名单)。

注意

此函数将静态反汇编检查点中的 pickle 文件。这意味着在反序列化期间动态推送到栈上的任何类将不包含在输出中。

参数

f (Union[str, PathLike[str], IO[bytes]]) – 文件类对象或包含通过 torch.save 保存的检查点对象的字符串

返回

检查点中未列入 weights_only 白名单的 pickle 全局变量字符串列表。

返回类型

list[str]

class torch.serialization.safe_globals(safe_globals)[source][source]

上下文管理器,将某些全局变量添加为对 weights_only 加载安全。

参数

safe_globals (list[Union[Callable, tuple[Callable, str]]]) – 用于 weights_only 加载的全局变量列表。

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     with torch.serialization.safe_globals([MyTensor]):
...         torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
class torch.serialization.skip_data(materialize_fake_tensors=False)[source][source]

上下文管理器,用于在 torch.save / torch.load 调用时跳过写入/读取存储字节。

对于保存路径,存储仍将保存,但通常写入其字节的空间将是空闲空间。随后可以在单独的过程中填充存储字节。

对于加载路径,张量将按检查点加载,但它们的存储不会填充数据。

警告

skip_data 上下文管理器是早期原型,可能会发生变化。

参数

materialize_fake_tensors (bool) – 是否在保存期间具体化 FakeTensors。这对于加载路径是空操作。

示例

>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
...     with torch.serialization.skip_data():
...         torch.save(t, f.name)
...     torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

配置

torch.utils.serialization.config 提供了一个全局配置,可以控制 torch.savetorch.load 的行为。

torch.utils.serialization.config.save 包含控制 torch.save 行为的选项。

  • compute_crc32:是否计算并写入 zip 文件校验和(默认值:True)。参见 set_crc32_options()

  • use_pinned_memory_for_d2h:对于传递给 torch.save 时位于加速器上的存储,是否在 torch.save 内将存储移动到 CPU 上的锁定内存或可分页内存(默认值:False(即可分页))。

  • storage_alignment:在 torch.save 期间检查点中存储的对齐方式(以字节为单位)。(默认值 64

torch.utils.serialization.config.load 包含控制 torch.load 行为的选项。

  • mmap:参见 torch.load()mmap 参数的文档。如果未显式传递给 torch.load 调用,此配置将设置 mmap 对于 torch.load 的行为(默认值:False)。

  • endianness:参见 set_default_load_endianness()。(默认值:torch.serialization.LoadEndianness.NATIVE

  • mmap_flags:参见 set_default_mmap_options。(默认值:MAP_PRIVATE

  • calculate_storage_offsets:如果此配置设置为 True,则在使用 torch.load(mmap=True) 时将计算存储的偏移量,而不是通过随机读取来读取。这最大程度地减少了随机读取,当文件通过网络加载时会很有帮助。(默认值:False

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源