torch.load¶
- torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)[源代码]¶
从文件中加载使用
torch.save()
保存的对象。torch.load()
使用 Python 的反序列化功能,但对作为张量基础的存储进行特殊处理。它们首先在 CPU 上反序列化,然后移动到它们保存到的设备。如果失败(例如,因为运行时系统不具备某些设备),则会引发异常。但是,可以使用map_location
参数将存储动态重新映射到另一组设备。如果
map_location
是一个可调用对象,它将为每个序列化的存储调用一次,并带有两个参数:存储和位置。存储参数将是存储的初始反序列化结果,驻留在 CPU 上。每个序列化的存储都与一个位置标记相关联,该标记标识其保存到的设备,此标记是传递给map_location
的第二个参数。'cpu'
是 CPU 张量的内置位置标记,'cuda:device_id'
(例如'cuda:2'
)是 CUDA 张量的内置位置标记。map_location
应返回None
或存储。如果map_location
返回存储,它将用作最终反序列化的对象,已移动到正确的设备。否则,torch.load()
将回退到默认行为,就像未指定map_location
一样。如果
map_location
是一个torch.device
对象或包含设备标记的字符串,它表示应加载所有张量的位置。否则,如果
map_location
是一个字典,它将用于将文件中出现的位置标记(键)重新映射到指定存储位置的标记(值)。用户扩展可以使用
torch.serialization.register_package()
注册他们自己的位置标记以及标记和反序列化方法。- 参数
f (联合[字符串, PathLike, BinaryIO, IO[字节]]) – 类文件对象(必须实现
read()
、readline()
、tell()
和seek()
),或包含文件名的字符串或 os.PathLike 对象map_location (可选[联合[可调用[[Storage, str], Storage], device, str, Dict[str, str]]]) – 一个函数、
torch.device
、字符串或一个字典,指定如何重新映射存储位置pickle_module (可选[任意]) – 用于解序列化元数据和对象的模块(必须与用于序列化文件的
pickle_module
匹配)weights_only (可选[布尔值]) – 指示解序列化器是否应仅限于加载张量、基本类型、字典和通过
torch.serialization.add_safe_globals()
添加的任何类型。mmap (可选[布尔值]) – 指示文件是否应该被内存映射,而不是将所有存储加载到内存中。通常,文件中的张量存储将首先从磁盘移动到 CPU 内存,然后将其移动到保存时标记的位置,或由
map_location
指定的位置。如果最终位置是 CPU,则此第二步是无操作的。当设置mmap
标志时,在第一步中,不会将张量存储从磁盘复制到 CPU 内存,而是将f
内存映射。pickle_load_args (任意) – (仅限 Python 3)传递给
pickle_module.load()
和pickle_module.Unpickler()
的可选关键字参数,例如errors=...
。
- 返回类型
警告
torch.load()
除非将weights_only参数设置为True,否则会隐式使用pickle
模块,该模块已知是不安全的。可以构造恶意的 pickle 数据,这些数据将在解序列化期间执行任意代码。切勿在不安全模式下加载可能来自不受信任来源或可能被篡改的数据。**仅加载您信任的数据**。注意
当您在包含 GPU 张量文件上调用
torch.load()
时,默认情况下,这些张量将加载到 GPU。您可以调用torch.load(.., map_location='cpu')
,然后load_state_dict()
来避免在加载模型检查点时出现 GPU 内存激增。注意
默认情况下,我们将字节字符串解码为
utf-8
。这是为了避免在 Python 3 中加载 Python 2 保存的文件时出现常见的错误情况UnicodeDecodeError: 'ascii' codec can't decode byte 0x...
。如果此默认值不正确,您可以使用额外的encoding
关键字参数指定如何加载这些对象,例如encoding='latin1'
使用latin1
编码将它们解码为字符串,而encoding='bytes'
将它们保留为字节数组,稍后可以使用byte_array.decode(...)
对其进行解码。示例
>>> torch.load("tensors.pt", weights_only=True) # Load all tensors onto the CPU >>> torch.load("tensors.pt", map_location=torch.device("cpu"), weights_only=True) # Load all tensors onto the CPU, using a function >>> torch.load( ... "tensors.pt", map_location=lambda storage, loc: storage, weights_only=True ... ) # Load all tensors onto GPU 1 >>> torch.load( ... "tensors.pt", ... map_location=lambda storage, loc: storage.cuda(1), ... weights_only=True, ... ) # type: ignore[attr-defined] # Map tensors from GPU 1 to GPU 0 >>> torch.load("tensors.pt", map_location={"cuda:1": "cuda:0"}, weights_only=True) # Load tensor from io.BytesIO object # Loading from a buffer setting weights_only=False, warning this can be unsafe >>> with open("tensor.pt", "rb") as f: ... buffer = io.BytesIO(f.read()) >>> torch.load(buffer, weights_only=False) # Load a module with 'ascii' encoding for unpickling # Loading from a module setting weights_only=False, warning this can be unsafe >>> torch.load("module.pt", encoding="ascii", weights_only=False)