快捷方式

torch.jit.freeze

torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)[源代码][源代码]

冻结 ScriptModule,内联子模块和属性作为常量。

冻结 ScriptModule 将克隆它,并尝试将克隆模块的子模块、参数和属性作为常量内联到 TorchScript IR 图中。默认情况下,将保留 forward,以及 preserved_attrs 中指定的属性和方法。此外,任何在保留方法中修改的属性也将被保留。

冻结目前只接受处于 eval 模式的 ScriptModule。

冻结应用通用优化,这将加快您的模型速度,而与机器无关。要使用服务器特定的设置进一步优化,请在冻结后运行 optimize_for_inference

参数
  • mod (ScriptModule) – 要冻结的模块

  • preserved_attrs (Optional[List[str]]) – 除了 forward 方法之外,还要保留的属性列表。在保留方法中修改的属性也将被保留。

  • optimize_numerics (bool) – 如果 True,将运行一组不严格保留数值的优化过程。优化的完整细节可以在 torch.jit.run_frozen_optimizations 中找到。

返回

冻结的 ScriptModule

示例 (冻结带有参数的简单模块)

    def forward(self, input):
        output = self.weight.mm(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0
# See the compiled graph as Python code
print(frozen_module.code)

示例 (冻结带有保留属性的模块)

    def forward(self, input):
        self.modified_tensor += 1
        return input + self.modified_tensor

scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)

注意

也支持冻结子模块属性:frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=[“submodule.version”])

注意

如果您不确定为什么某个属性没有作为常量内联,您可以在 frozen_module.forward.graph 上运行 dump_alias_db,以查看冻结是否检测到该属性正在被修改。

注意

由于冻结使权重成为常量并删除了模块层次结构,因此 to 和其他 nn.Module 方法来操作设备或 dtype 不再起作用。作为一种解决方法,您可以通过在 torch.jit.load 中指定 map_location 来重新映射设备,但是特定于设备的逻辑可能已 baked 到模型中。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源