快捷方式

TorchScript

TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。任何 TorchScript 程序都可以从 Python 进程中保存,并在没有 Python 依赖项的进程中加载。

我们提供了将模型从纯 Python 程序增量转换为可以在没有 Python 的情况下运行的 TorchScript 程序的工具,例如在独立的 C++ 程序中。这使得可以使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后通过 TorchScript 将模型导出到生产环境中,在这些环境中,Python 程序可能由于性能和多线程原因而处于劣势。

有关 TorchScript 的简单介绍,请参阅TorchScript 简介教程。

有关将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行它的端到端示例,请参阅在 C++ 中加载 PyTorch 模型教程。

创建 TorchScript 代码

script

脚本化函数。

trace

跟踪函数并返回将使用即时编译优化的可执行文件或ScriptFunction

script_if_tracing

在跟踪期间首次调用fn时编译它。

trace_module

跟踪模块并返回将使用即时编译优化的可执行ScriptModule

fork

创建一个执行func的异步任务,并引用此执行结果的值。

wait

强制完成torch.jit.Future[T]异步任务,返回任务的结果。

ScriptModule

C++ torch::jit::Module 的包装器,具有方法、属性和参数。

ScriptFunction

在功能上等效于ScriptModule,但表示单个函数,并且没有任何属性或参数。

freeze

冻结 ScriptModule,将子模块和属性内联为常量。

optimize_for_inference

执行一组优化过程,以优化模型以用于推理。

enable_onednn_fusion

根据参数enabled启用或禁用 onednn JIT 融合。

onednn_fusion_enabled

返回是否启用了 onednn JIT 融合。

set_fusion_strategy

设置在融合过程中可能出现的专业化类型和数量。

strict_fusion

如果在推理中并非所有节点都已融合,或者在训练中进行了符号微分,则给出错误。

save

保存此模块的脱机版本,以在单独的进程中使用。

load

加载先前使用torch.jit.save保存的ScriptModuleScriptFunction

ignore

此装饰器向编译器指示应忽略某个函数或方法,并将其保留为 Python 函数。

unused

此装饰器向编译器指示应忽略某个函数或方法,并将其替换为引发异常。

interface

使用装饰器来注释不同类型的类或模块。

isinstance

在 TorchScript 中提供容器类型细化。

Attribute

此方法是一个传递函数,它返回value,主要用于向 TorchScript 编译器指示左侧表达式是类型为type的类实例属性。

annotate

用于在 TorchScript 编译器中提供the_value的类型。

混合跟踪和脚本化

在许多情况下,跟踪或脚本化都是将模型转换为 TorchScript 的更简单方法。可以组合跟踪和脚本化以适应模型部分的特定要求。

脚本化函数可以调用跟踪函数。当您需要围绕简单的顺序模型使用控制流时,这特别有用。例如,序列到序列模型的波束搜索通常用脚本编写,但可以调用使用跟踪生成的编码器模块。

示例(在脚本中调用跟踪函数)

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

跟踪函数可以调用脚本函数。当模型的一小部分需要一些控制流,而模型的大部分只是一个前馈网络时,这非常有用。由跟踪函数调用的脚本函数内的控制流可以被正确保留。

示例(在跟踪函数中调用脚本函数)

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

这种组合也适用于 nn.Module,它可以用于生成一个使用跟踪的子模块,该子模块可以从脚本模块的方法中调用。

示例(使用跟踪模块)

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

TorchScript 语言

TorchScript 是 Python 的一个静态类型子集,因此许多 Python 特性可以直接应用于 TorchScript。有关详细信息,请参阅完整的 TorchScript 语言参考

内置函数和模块

TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置函数。有关受支持函数的完整参考,请参阅 TorchScript 内置函数

PyTorch 函数和模块

TorchScript 支持 PyTorch 提供的张量和神经网络函数的一个子集。Tensor 上的大多数方法以及 torch 命名空间中的函数,torch.nn.functional 中的所有函数以及 torch.nn 中的大多数模块都在 TorchScript 中受支持。

有关不受支持的 PyTorch 函数和模块的列表,请参阅 TorchScript 不支持的 PyTorch 结构

Python 函数和模块

TorchScript 支持许多 Python 的 内置函数math 模块也受支持(有关详细信息,请参阅 math 模块),但不支持其他 Python 模块(内置或第三方)。

Python 语言参考比较

有关受支持的 Python 功能的完整列表,请参阅 Python 语言参考覆盖范围

调试

禁用 JIT 进行调试

PYTORCH_JIT

设置环境变量 PYTORCH_JIT=0 将禁用所有脚本和跟踪注释。如果您的某个 TorchScript 模型中存在难以调试的错误,则可以使用此标志强制所有内容都使用原生 Python 运行。由于使用此标志禁用了 TorchScript(脚本和跟踪),因此您可以使用 pdb 等工具来调试模型代码。例如

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

使用 pdb 调试此脚本是有效的,除非我们调用 @torch.jit.script 函数。我们可以全局禁用 JIT,以便我们可以将 @torch.jit.script 函数作为普通的 Python 函数调用,而不是对其进行编译。如果上面的脚本名为 disable_jit_example.py,我们可以像这样调用它

$ PYTORCH_JIT=0 python disable_jit_example.py

并且我们将能够像使用普通 Python 函数一样进入 @torch.jit.script 函数。要为特定函数禁用 TorchScript 编译器,请参阅 @torch.jit.ignore

检查代码

TorchScript 为所有 ScriptModule 实例提供了一个代码美化打印器。此美化打印器将脚本方法的代码解释为有效的 Python 语法。例如

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

具有单个 forward 方法的 ScriptModule 将具有属性 code,您可以使用它来检查 ScriptModule 的代码。如果 ScriptModule 具有多个方法,则需要访问方法本身上的 .code,而不是模块。我们可以通过访问 .foo.code 来检查 ScriptModule 上名为 foo 的方法的代码。上面的示例产生以下输出

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

这是 TorchScript 对 forward 方法代码的编译结果。您可以使用它来确保 TorchScript(跟踪或脚本)正确捕获了您的模型代码。

解释图形

TorchScript 还具有一种比代码美化打印器更低级的表示形式,即 IR 图。

TorchScript 使用静态单赋值 (SSA) 中间表示 (IR) 来表示计算。这种格式的指令包括 ATen(PyTorch 的 C++ 后端)运算符和其他基本运算符,包括用于循环和条件语句的控制流运算符。例如

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph 遵循 检查代码 部分中描述的关于 forward 方法查找的相同规则。

上面的示例脚本生成以下图形

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

以指令 %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 为例。

  • %rv.1 : Tensor 表示我们将输出分配给一个名为 rv.1 的(唯一)值,该值的类型为 Tensor,并且我们不知道其具体形状。

  • aten::zeros 是运算符(等效于 torch.zeros),输入列表 (%4, %6, %6, %10, %12) 指定应将范围内的哪些值作为输入传递。可以在 内置函数 中找到 aten::zeros 等内置函数的架构。

  • # test.py:9:10 是生成此指令的原始源文件中的位置。在本例中,它是一个名为 test.py 的文件,位于第 9 行,第 10 个字符。

请注意,运算符还可以具有关联的 ,即 prim::Loopprim::If 运算符。在图形打印输出中,这些运算符的格式与其等效的源代码形式相对应,以便于调试。

可以按如下所示检查图形,以确认 ScriptModule 描述的计算是正确的,既可以自动进行,也可以手动进行,如下所述。

跟踪器

跟踪边缘情况

在某些边缘情况下,给定 Python 函数/模块的跟踪不能代表底层代码。这些情况可能包括

  • 跟踪依赖于输入(例如,张量形状)的控制流

  • 跟踪张量视图的原地操作(例如,赋值左侧的索引)

请注意,这些情况在将来实际上可能是可以跟踪的。

自动跟踪检查

自动捕获跟踪中的许多错误的一种方法是在 torch.jit.trace() API 上使用 check_inputscheck_inputs 接受一个输入元组列表,这些输入将用于重新跟踪计算并验证结果。例如

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

为我们提供以下诊断信息

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

此消息向我们表明,在我们第一次跟踪计算时和在我们使用 check_inputs 跟踪计算时,计算结果不同。实际上,loop_in_traced_fn 体内的循环依赖于输入 x 的形状,因此当我们尝试使用具有不同形状的另一个 x 时,跟踪结果会有所不同。

在这种情况下,可以使用 torch.jit.script() 来捕获这种依赖于数据的控制流,例如

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))

产生

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

跟踪器警告

对于跟踪计算中的几种有问题的模式,跟踪器会生成警告。例如,获取包含对 Tensor 切片(视图)进行原地赋值的函数的跟踪

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

产生几个警告和一个简单地返回输入的图形

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

我们可以通过修改代码来解决这个问题,不使用原地更新,而是使用 torch.cat 异地构建结果张量

def fill_row_zero(x):
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

常见问题解答

问:我想在 GPU 上训练模型并在 CPU 上进行推理。最佳实践是什么?

首先将模型从 GPU 转换为 CPU,然后保存,如下所示

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
  model = torch.jit.load("gpu.pt")
else:
  model = torch.jit.load("cpu.pt")

model(input)

建议这样做,因为跟踪器可能会观察到在特定设备上创建了张量,因此转换已加载的模型可能会产生意外的影响。在保存模型*之前*转换模型可确保跟踪器具有正确的设备信息。

问:如何在 ScriptModule 上存储属性?

假设我们有一个类似这样的模型

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x = 2

    def forward(self):
        return self.x

m = torch.jit.script(Model())

如果实例化 Model,则会导致编译错误,因为编译器不知道 x。有 4 种方法可以将 ScriptModule 上的属性告知编译器

1. nn.Parameter - 包装在 nn.Parameter 中的值将像在 nn.Module 上一样工作

2. register_buffer - 包装在 register_buffer 中的值将像在 nn.Module 上一样工作。这等效于类型为 Tensor 的属性(见 4)。

3. 常量 - 将类成员注释为 Final(或在类定义级别将其添加到名为 __constants__ 的列表中)会将包含的名称标记为常量。常量直接保存在模型的代码中。有关详细信息,请参阅内置常量

4. 属性 - 可以将 支持的类型 的值添加为可变属性。大多数类型都可以推断出来,但有些类型可能需要指定,有关详细信息,请参阅模块属性

问:我想跟踪模块的方法,但我一直收到此错误

运行时错误: 无法 插入 需要 梯度 张量 作为 常量。 考虑 将其 设为 参数 输入, 分离 梯度

此错误通常意味着您要跟踪的方法使用了模块的参数,并且您传递的是模块的方法而不是模块实例(例如 my_module_instance.forwardmy_module_instance)。

  • 使用模块的方法调用 trace 会将模块参数(可能需要梯度)捕获为**常量**。

  • 另一方面,使用模块实例(例如 my_module)调用 trace 会创建一个新模块并将参数正确复制到新模块中,因此它们可以在需要时累积梯度。

要跟踪模块上的特定方法,请参阅 torch.jit.trace_module

已知问题

如果您将 Sequential 与 TorchScript 一起使用,则某些 Sequential 子模块的输入可能会被错误地推断为 Tensor,即使它们有其他注释。规范的解决方案是对 nn.Sequential 进行子类化,并使用正确类型的输入重新声明 forward

附录

迁移到 PyTorch 1.2 递归脚本 API

本节详细介绍了 PyTorch 1.2 中对 TorchScript 的更改。如果您是 TorchScript 的新手,可以跳过本节。在 PyTorch 1.2 中,TorchScript API 主要有两处更改。

1. torch.jit.script 现在将尝试递归编译它遇到的函数、方法和类。一旦调用 torch.jit.script,编译就是“选择退出”,而不是“选择加入”。

2. torch.jit.script(nn_module_instance) 现在是创建 ScriptModule 的首选方法,而不是继承自 torch.jit.ScriptModule。这些更改结合在一起,提供了一个更简单、更易于使用的 API,用于将 nn.Module 转换为 ScriptModule,准备在非 Python 环境中进行优化和执行。

新的用法如下所示

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)
  • 默认情况下编译模块的 forward。从 forward 调用的方法将按照它们在 forward 中的使用顺序进行延迟编译。

  • 要编译不是从 forward 调用的 forward 以外的方法,请添加 @torch.jit.export

  • 要阻止编译器编译方法,请添加 @torch.jit.ignore@torch.jit.unused@ignore

  • 方法保留为对 Python 的调用,@unused 将其替换为异常。 @ignored 无法导出;@unused 可以。

  • 大多数属性类型都可以推断出来,因此不需要 torch.jit.Attribute。对于空容器类型,请使用 PEP 526 样式 的类注释来注释它们的类型。

  • 可以使用 Final 类注释标记常量,而不是将成员的名称添加到 __constants__ 中。

  • Python 3 类型提示可以代替 torch.jit.annotate 使用

由于这些更改,以下项被视为已弃用,不应出现在新代码中
  • @torch.jit.script_method 装饰器

  • 继承自 torch.jit.ScriptModule 的类

  • torch.jit.Attribute 包装类

  • __constants__ 数组

  • torch.jit.annotate 函数

模块

警告

@torch.jit.ignore 注释的行为在 PyTorch 1.2 中发生了变化。在 PyTorch 1.2 之前,@ignore 装饰器用于使函数或方法可以从导出的代码中调用。要恢复此功能,请使用 @torch.jit.unused()@torch.jit.ignore 现在等效于 @torch.jit.ignore(drop=False)。有关详细信息,请参阅 @torch.jit.ignore@torch.jit.unused

当传递给 torch.jit.script 函数时,torch.nn.Module 的数据将复制到 ScriptModule 中,并且 TorchScript 编译器会编译该模块。默认情况下编译模块的 forward。从 forward 调用的方法将按照它们在 forward 中的使用顺序进行延迟编译,任何 @torch.jit.export 方法也是如此。

torch.jit.export(fn)[源代码]

此装饰器指示 nn.Module 上的方法用作 ScriptModule 的入口点,并且应该进行编译。

隐式地假定 forward 是一个入口点,因此它不需要此装饰器。从 forward 调用的函数和方法将在编译器看到它们时进行编译,因此它们也不需要此装饰器。

示例(在方法上使用 @torch.jit.export

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

函数

函数的变化不大,如果需要,可以使用 @torch.jit.ignoretorch.jit.unused 对其进行装饰。

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
  import pdb; pdb.set_trace()
  return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript 类

警告

TorchScript 对类的支持还处于实验阶段。目前,它最适合于简单的类记录类型(例如带有方法的 NamedTuple)。

用户定义的 TorchScript 类 中的所有内容默认都会被导出,如果需要,可以使用 @torch.jit.ignore 修饰函数。

属性

TorchScript 编译器需要知道模块属性的类型。大多数类型可以从成员的值推断出来。空列表和字典无法推断其类型,必须使用 PEP 526 风格 的类注释来标注其类型。如果无法推断类型并且没有显式注释,则不会将其作为属性添加到生成的 ScriptModule 中。

旧 API

from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

新 API

from typing import Dict

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super().__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())

常量

Final 类型构造函数可用于将成员标记为常量。如果成员未标记为常量,它们将作为属性复制到生成的 ScriptModule 中。如果已知值是固定的,则使用 Final 可以优化代码并提高类型安全性。

旧 API

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

新 API

from typing import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super().__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

变量

容器的类型默认为 Tensor 并且不可选(有关详细信息,请参阅默认类型)。以前,使用 torch.jit.annotate 告诉 TorchScript 编译器类型应该是什么。现在支持 Python 3 风格的类型提示。

import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b

融合后端

有几个融合后端可用于优化 TorchScript 执行。CPU 上的默认融合器是 NNC,它可以对 CPU 和 GPU 执行融合。GPU 上的默认融合器是 NVFuser,它支持更广泛的运算符,并且已证明生成的内核具有更高的吞吐量。有关使用和调试的更多详细信息,请参阅 NVFuser 文档

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的答案

查看资源