torch.jit.trace¶
- torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[源代码][源代码]¶
追踪一个函数并返回一个可执行文件或
ScriptFunction
,该函数将使用即时编译进行优化。追踪非常适合仅对
Tensor
及其列表、字典和元组进行操作的代码。使用 torch.jit.trace 和 torch.jit.trace_module,您可以将现有模块或 Python 函数转换为 TorchScript
ScriptFunction
或ScriptModule
。您必须提供示例输入,我们将运行该函数,记录在所有张量上执行的操作。独立函数的追踪结果会生成 ScriptFunction。
nn.Module.forward 或 nn.Module 的追踪结果会生成 ScriptModule。
这个模块也包含原始模块拥有的所有参数。
警告
追踪只正确记录不依赖于数据的函数和模块(例如,在张量中的数据上没有条件判断),并且没有任何未被追踪的外部依赖(例如,执行输入/输出或访问全局变量)。追踪只记录给定函数在给定张量上运行时执行的操作。因此,返回的 ScriptModule 在任何输入上都将运行相同的追踪图。当您的模块根据输入和/或模块状态预期运行不同的操作集时,这会产生一些重要的影响。例如,
追踪不会记录任何控制流,例如 if 语句或循环。当这个控制流在整个模块中是常量时,这很好,它通常会内联控制流决策。但有时控制流实际上是模型本身的一部分。例如,循环网络是根据输入序列(可能是动态的)长度进行的循环。
在返回的
ScriptModule
中,在training
和eval
模式下行为不同的操作,无论 ScriptModule 处于哪种模式,都将始终表现得如同追踪时的模式一样。
在这些情况下,追踪可能不适用,而
scripting
是更好的选择。如果您追踪此类模型,在后续调用模型时可能会默默地得到错误的结果。追踪器在执行可能导致生成错误追踪的操作时会尝试发出警告。- 参数
func (可调用对象 或 torch.nn.Module) – 一个 Python 函数或 torch.nn.Module,将使用 example_inputs 运行。 func 的参数和返回值必须是张量或包含张量的(可能是嵌套的)元组。当一个模块被传递给 torch.jit.trace 时,只运行并追踪其
forward
方法(详情参见torch.jit.trace
)。- 关键字参数
example_inputs (元组 或 torch.Tensor 或 None, 可选) – 在追踪时将传递给函数的示例输入元组。默认值:
None
。必须指定此参数或example_kwarg_inputs
。假设追踪的操作支持不同的类型和形状,结果追踪可以与具有不同类型和形状的输入一起运行。example_inputs 也可以是单个 Tensor,在这种情况下它会自动被包装在一个元组中。当值为 None 时,应指定example_kwarg_inputs
。check_trace (
bool
, 可选) – 检查相同的输入通过追踪代码是否产生相同的输出。默认值:True
。如果您确定网络正确,即使检查器失败,或者您的网络包含非确定性操作,您可能想要禁用此选项。check_inputs (元组 列表, 可选) – 用于检查追踪结果与预期是否一致的输入参数元组列表。每个元组等同于
example_inputs
中指定的一组输入参数。为获得最佳结果,请传入一组代表您预期网络将看到的输入形状和类型空间的检查输入。如果未指定,则使用原始的example_inputs
进行检查。check_tolerance (float, 可选) – 在检查过程中使用的浮点比较容差。当结果由于已知原因(如运算符融合)导致数值偏差时,可以使用此参数放宽检查器的严格性。
strict (
bool
, 可选) – 是否以严格模式运行追踪器(默认值:True
)。仅当您希望追踪器记录您的可变容器类型(目前是list
/dict
)并且您确定您在问题中使用的容器是constant
结构且不被用作控制流(if, for)条件时,才禁用此选项。example_kwarg_inputs (字典, 可选) – 此参数是一组关键字参数形式的示例输入,将在追踪时传递给函数。默认值:
None
。必须指定此参数或example_inputs
。该字典将通过被追踪函数的参数名称进行解包。如果字典的键与被追踪函数的参数名称不匹配,将引发运行时异常。
- 返回值
如果 func 是 nn.Module 或 nn.Module 的
forward
方法,trace 返回一个ScriptModule
对象,该对象具有包含追踪代码的单个forward
方法。返回的 ScriptModule 将拥有与原始nn.Module
相同的子模块和参数集合。如果func
是一个独立函数,trace
返回 ScriptFunction。
示例(追踪函数)
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
示例(追踪现有模块)
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)