torch.jit.trace¶
- torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[源代码]¶
跟踪一个函数并返回一个可执行的或
ScriptFunction
,它将使用即时编译进行优化。跟踪非常适合仅对
Tensor
和列表、字典以及包含Tensor
的元组进行操作的代码。使用 torch.jit.trace 和 torch.jit.trace_module,您可以将现有的模块或 Python 函数转换为 TorchScript
ScriptFunction
或ScriptModule
。您必须提供示例输入,然后我们会运行该函数,记录对所有张量执行的操作。对独立函数进行记录产生的结果是 ScriptFunction。
对 nn.Module.forward 或 nn.Module 进行记录产生的结果是 ScriptModule。
此模块还包含原始模块拥有的任何参数。
警告
跟踪只能正确记录那些不依赖于数据的函数和模块(例如,不包含对张量中数据的条件语句)并且没有任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。跟踪只会记录在给定函数在给定张量上运行时执行的操作。因此,返回的 ScriptModule 将始终对任何输入运行相同的已跟踪图。当您的模块期望运行不同的操作集(取决于输入和/或模块状态)时,这有一些重要的含义。例如,
跟踪不会记录任何控制流,例如 if 语句或循环。当此控制流在您的模块中保持一致时,这很好,它通常会内联控制流决策。但是,有时控制流实际上是模型本身的一部分。例如,循环神经网络是针对输入序列(可能动态的)长度进行循环的。
在返回的
ScriptModule
中,在training
和eval
模式下具有不同行为的操作将始终表现得好像它处于跟踪期间所在的模式,无论 ScriptModule 处于哪种模式。
在这些情况下,跟踪将不合适,
脚本化
是更好的选择。如果您跟踪此类模型,则可能会在随后调用模型时静默地获得不正确的结果。跟踪器会在执行可能导致产生不正确跟踪的操作时尝试发出警告。- 参数
func (可调用 或 torch.nn.Module) – 将使用 example_inputs 运行的 Python 函数或 torch.nn.Module。func 参数和返回值必须是张量或(可能嵌套的)包含张量的元组。当传递模块时 torch.jit.trace,只会运行和跟踪
forward
方法(请参阅torch.jit.trace
获取详细信息)。- 关键字参数
example_inputs (元组 或 torch.Tensor 或 None, 可选) – 传递给函数以进行跟踪的示例输入元组。默认值:
None
。应指定此参数或example_kwarg_inputs
。生成的跟踪可以使用不同类型和形状的输入运行,假设跟踪的操作支持这些类型和形状。example_inputs 也可能是一个单独的张量,在这种情况下它将自动包装在一个元组中。当值为 None 时,应指定example_kwarg_inputs
。check_trace (
bool
, 可选) – 检查相同输入是否通过跟踪的代码产生相同的输出。默认值:True
。例如,如果您的网络包含非确定性操作,或者您确定网络即使在检查器失败的情况下也是正确的,您可能希望禁用此功能。check_inputs (列表 的 元组, 可选) – 用于检查跟踪结果是否符合预期的输入参数元组列表。每个元组相当于在
example_inputs
中指定的输入参数集。为了获得最佳结果,请传入一组代表您期望网络看到的输入形状和类型空间的检查输入。如果未指定,则原始example_inputs
用于检查。check_tolerance (浮点数, 可选) – 检查程序中使用的浮点比较容差。这可用于在结果因已知原因(例如运算符融合)在数值上出现偏差的情况下,放宽检查器严格性。
strict (
bool
, 可选) – 以严格模式运行跟踪器还是否(默认值:True
)。仅在您希望跟踪器记录可变容器类型(当前为list
/dict
)并且您确定在问题中使用的容器是constant
结构且不用作控制流(if,for)条件时关闭此选项。example_kwarg_inputs (字典, 可选) – 此参数是一组将传递给函数以进行跟踪的示例输入的关键字参数。默认值:
None
。应指定此参数或example_inputs
。字典将通过跟踪函数的参数名称解包。如果字典的键与跟踪函数的参数名称不匹配,将引发运行时异常。
- 返回值
如果 func 是 nn.Module 或
forward
的 nn.Module,trace 返回一个具有单个forward
方法的ScriptModule
对象,其中包含跟踪的代码。返回的 ScriptModule 将具有与原始nn.Module
相同的子模块和参数集。如果func
是一个独立函数,trace
返回 ScriptFunction。
示例(跟踪函数)
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
示例(跟踪现有模块)
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)