快捷方式

torch.jit.trace_module

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[源代码]

跟踪模块并返回一个可执行的 ScriptModule,该模块将使用即时编译进行优化。

当将模块传递给 torch.jit.trace 时,仅运行和跟踪 forward 方法。使用 trace_module,您可以指定一个方法名称到示例输入的字典来进行跟踪(请参见下面的 inputs 参数)。

有关跟踪的更多信息,请参见 torch.jit.trace

参数
  • mod (torch.nn.Module) – 包含 inputs 中指定的方法名称的 torch.nn.Module。给定的方法将作为单个 ScriptModule 的一部分进行编译。

  • inputs (dict) – 包含按 mod 中方法名称索引的示例输入的字典。在跟踪期间,这些输入将传递给方法,方法名称与输入的键相对应。{ 'forward' : example_forward_input, 'method2': example_method2_input}

关键字参数
  • check_trace (bool,可选) – 检查通过跟踪代码运行的相同输入是否产生相同的输出。默认值:True。例如,如果您确定网络是正确的,即使检查器失败,您可能也希望禁用此功能,因为您的网络包含非确定性操作。

  • check_inputs (list of dicts, optional) – 用于检查跟踪与预期结果是否一致的输入参数字典列表。每个元组都等效于一组将在 inputs 中指定的输入参数。为了获得最佳效果,请传入一组代表您期望网络看到的输入形状和类型空间的检查输入。如果未指定,则使用原始 inputs 进行检查

  • check_tolerance (float, optional) – 在检查器过程中使用的浮点比较容差。如果结果由于已知原因(例如运算符融合)在数值上发生偏差,则可以使用它来降低检查器的严格性。

  • example_inputs_is_kwarg (bool,可选) – 此参数指示示例输入是否为关键字参数包。默认值:False

返回值

一个包含跟踪代码的单个 forward 方法的 ScriptModule 对象。当 functorch.nn.Module 时,返回的 ScriptModule 将与 func 具有相同的子模块和参数集。

示例(跟踪具有多个方法的模块)

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源