torch.jit.save¶
- torch.jit.save(m, f, _extra_files=None)[源代码]¶
保存此模块的离线版本以供在单独的进程中使用。
保存的模块将序列化此模块的所有方法、子模块、参数和属性。它可以使用
torch::jit::load(filename)
加载到 C++ API 中,或使用torch.jit.load
加载到 Python API 中。为了能够保存模块,它不能调用任何本地 Python 函数。这意味着所有子模块也必须是
ScriptModule
的子类。危险
所有模块,无论其设备如何,在加载期间始终加载到 CPU 上。这与
torch.load()
的语义不同,并且将来可能会更改。- 参数
m – 要保存的
ScriptModule
。f – 类文件对象(必须实现 write 和 flush)或包含文件名的字符串。
_extra_files – 从文件名到内容的映射,这些内容将作为 f 的一部分存储。
注意
torch.jit.save 尝试在不同版本之间保留某些运算符的行为。例如,在 PyTorch 1.5 中,两个整数张量相除执行的是向下取整除法,如果包含该代码的模块在 PyTorch 1.5 中保存并在 PyTorch 1.6 中加载,则其除法行为将保留。但是,在 PyTorch 1.6 中保存的同一模块将无法在 PyTorch 1.5 中加载,因为除法的行为在 1.6 中发生了改变,而 1.5 不知道如何复制 1.6 的行为。
示例: .. testcode
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 m = torch.jit.script(MyModule()) # Save to file torch.jit.save(m, 'scriptmodule.pt') # This line is equivalent to the previous m.save("scriptmodule.pt") # Save to io.BytesIO buffer buffer = io.BytesIO() torch.jit.save(m, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'} torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)