TorchScript¶
TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。任何 TorchScript 程序都可以从 Python 进程中保存,并在没有 Python 依赖的进程中加载。
我们提供了工具,可以将模型从纯 Python 程序逐步转换为可独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中。这使得可以使用 Python 中熟悉的工具在 PyTorch 中训练模型,然后通过 TorchScript 将模型导出到生产环境,在生产环境中,由于性能和多线程原因,Python 程序可能不利。
关于 TorchScript 的温和介绍,请参阅TorchScript 入门教程。
关于将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行的端到端示例,请参阅在 C++ 中加载 PyTorch 模型教程。
创建 TorchScript 代码¶
对函数进行脚本化。 |
|
跟踪一个函数并返回一个可执行文件或 |
|
在跟踪期间首次调用 |
|
跟踪一个模块并返回一个可执行的 |
|
创建一个执行 func 的异步任务,并返回此执行结果值的引用。 |
|
强制完成 torch.jit.Future[T] 异步任务,返回任务结果。 |
|
C++ torch::jit::Module 的包装器,包含方法、属性和参数。 |
|
功能上等同于 |
|
冻结 ScriptModule,将子模块和属性内联为常量。 |
|
执行一系列优化过程,以优化模型用于推理。 |
|
根据参数 enabled 启用或禁用 onednn JIT 融合。 |
|
返回 onednn JIT 融合是否已启用。 |
|
设置融合期间可能发生的特化类型和数量。 |
|
如果在推理中并非所有节点都被融合,或者在训练中并非所有节点都被符号化微分,则报错。 |
|
保存此模块的离线版本,以便在单独的进程中使用。 |
|
加载先前使用 |
|
此装饰器指示编译器应忽略某个函数或方法,并将其保留为 Python 函数。 |
|
此装饰器指示编译器应忽略某个函数或方法,并将其替换为抛出异常。 |
|
用于装饰以注释不同类型的类或模块。 |
|
在 TorchScript 中提供容器类型细化。 |
|
此方法是一个直通函数,返回 value,主要用于向 TorchScript 编译器指示左侧表达式是一个类型为 type 的类实例属性。 |
|
用于在 TorchScript 编译器中指定 the_value 的类型。 |
混合跟踪和脚本¶
在许多情况下,跟踪或脚本化是转换模型到 TorchScript 的更简单方法。跟踪和脚本化可以组合使用,以适应模型中特定部分的需求。
脚本化函数可以调用跟踪函数。这在需要围绕简单前馈模型使用控制流时特别有用。例如,序列到序列模型的束搜索通常以脚本编写,但可以调用使用跟踪生成的编码器模块。
示例(在脚本中调用跟踪函数)
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
return traced_foo(x, x)
跟踪函数可以调用脚本函数。这在模型的很小一部分需要一些控制流,即使模型的大部分只是前馈网络时非常有用。跟踪函数调用的脚本函数内部的控制流可以正确保留。
示例(在跟踪函数中调用脚本函数)
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
def bar(x, y, z):
return foo(x, y) + z
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
这种组合也适用于 nn.Module
,可以用于使用跟踪生成一个子模块,该子模块可以从脚本模块的方法中调用。
示例(使用跟踪模块)
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
TorchScript 语言¶
TorchScript 是 Python 的一个静态类型子集,因此许多 Python 特性直接适用于 TorchScript。有关详细信息,请参阅完整的 TorchScript 语言参考。
内置函数和模块¶
TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置函数。有关支持函数的完整参考,请参阅 TorchScript 内置函数。
PyTorch 函数和模块¶
TorchScript 支持 PyTorch 提供的部分张量和神经网络函数。Tensor 的大多数方法以及 torch
命名空间中的函数、torch.nn.functional
中的所有函数以及 torch.nn
中的大多数模块都受 TorchScript 支持。
有关不支持的 PyTorch 函数和模块的列表,请参阅 TorchScript 不支持的 PyTorch 结构。
Python 函数和模块¶
TorchScript 支持许多 Python 的内置函数。还支持 math
模块(详情请参阅 math 模块),但不支持其他 Python 模块(内置或第三方)。
Python 语言参考比较¶
有关支持的 Python 特性的完整列表,请参阅 Python 语言参考覆盖范围。
调试¶
禁用 JIT 进行调试¶
- PYTORCH_JIT¶
设置环境变量 PYTORCH_JIT=0
将禁用所有脚本和跟踪注解。如果在您的某个 TorchScript 模型中存在难以调试的错误,您可以使用此标志强制所有内容使用原生 Python 运行。由于使用此标志禁用了 TorchScript(脚本化和跟踪),您可以使用 pdb
等工具来调试模型代码。例如
@torch.jit.script
def scripted_fn(x : torch.Tensor):
for i in range(12):
x = x + x
return x
def fn(x):
x = torch.neg(x)
import pdb; pdb.set_trace()
return scripted_fn(x)
traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
使用 pdb
调试此脚本是可行的,但当我们调用 @torch.jit.script
函数时除外。我们可以全局禁用 JIT,以便可以将 @torch.jit.script
函数作为普通 Python 函数调用而不对其进行编译。如果上述脚本名为 disable_jit_example.py
,我们可以像这样调用它
$ PYTORCH_JIT=0 python disable_jit_example.py
然后我们将能够像进入普通 Python 函数一样进入 @torch.jit.script
函数。要禁用 TorchScript 编译器对特定函数的编译,请参阅 @torch.jit.ignore
。
检查代码¶
TorchScript 为所有 ScriptModule
实例提供了一个代码美化打印器。此美化打印器将脚本方法的代码解释为有效的 Python 语法。例如
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.code)
具有单个 forward
方法的 ScriptModule
将拥有一个属性 code
,您可以使用它来检查 ScriptModule
的代码。如果 ScriptModule
有多个方法,则需要访问方法本身的 .code
而不是模块的。我们可以通过访问 .foo.code
来检查 ScriptModule
上名为 foo
的方法的代码。上面的示例产生以下输出
def foo(len: int) -> Tensor:
rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
rv0 = rv
for i in range(len):
if torch.lt(i, 10):
rv1 = torch.sub(rv0, 1., 1)
else:
rv1 = torch.add(rv0, 1., 1)
rv0 = rv1
return rv0
这是 TorchScript 对 forward
方法代码的编译结果。您可以使用此结果来确保 TorchScript(跟踪或脚本化)正确地捕获了您的模型代码。
解释图¶
TorchScript 也有比代码美化打印器更低层次的表示形式,即 IR 图。
TorchScript 使用静态单赋值(SSA)中间表示(IR)来表示计算。这种格式的指令由 ATen(PyTorch 的 C++ 后端)算子和其他基本算子组成,包括用于循环和条件的控制流算子。例如
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
print(foo.graph)
关于 forward
方法查找,graph
遵循 检查代码 部分描述的相同规则。
上述示例脚本产生以下图
graph(%len.1 : int):
%24 : int = prim::Constant[value=1]()
%17 : bool = prim::Constant[value=1]() # test.py:10:5
%12 : bool? = prim::Constant()
%10 : Device? = prim::Constant()
%6 : int? = prim::Constant()
%1 : int = prim::Constant[value=3]() # test.py:9:22
%2 : int = prim::Constant[value=4]() # test.py:9:25
%20 : int = prim::Constant[value=10]() # test.py:11:16
%23 : float = prim::Constant[value=1]() # test.py:12:23
%4 : int[] = prim::ListConstruct(%1, %2)
%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
%rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
block0(%i.1 : int, %rv.14 : Tensor):
%21 : bool = aten::lt(%i.1, %20) # test.py:11:12
%rv.13 : Tensor = prim::If(%21) # test.py:11:9
block0():
%rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
-> (%rv.3)
block1():
%rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
-> (%rv.6)
-> (%17, %rv.13)
return (%rv)
以指令 %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
为例。
%rv.1 : Tensor
表示我们将输出分配给一个(唯一的)名为rv.1
的值,该值的类型为Tensor
,且我们不知道它的具体形状。aten::zeros
是算子(等同于torch.zeros
),输入列表(%4, %6, %6, %10, %12)
指定了范围内的哪些值应作为输入传递。内置函数如aten::zeros
的 schema 可以在 内置函数 中找到。# test.py:9:10
是原始源文件中生成此指令的位置。在本例中,它是一个名为 test.py 的文件,在第 9 行的第 10 个字符处。
注意,算子也可以有关联的 blocks
,即 prim::Loop
和 prim::If
算子。在图的打印输出中,这些算子被格式化以反映其等效的源代码形式,以便于调试。
如图所示,可以检查图以确认由 ScriptModule
描述的计算是正确的,无论是通过自动化还是手动方式,如下所述。
跟踪器¶
跟踪的边缘情况¶
存在一些边缘情况,在这些情况下,给定 Python 函数/模块的跟踪结果将无法代表底层代码。这些情况可能包括
依赖于输入的控制流的跟踪(例如张量形状)
张量视图的就地操作的跟踪(例如赋值语句左侧的索引)
请注意,这些情况实际上可能在将来变得可跟踪。
自动跟踪检查¶
自动捕获跟踪中许多错误的一种方法是在 torch.jit.trace()
API 上使用 check_inputs
。check_inputs
接受一个输入元组列表,这些输入将用于重新跟踪计算并验证结果。例如
def loop_in_traced_fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
为我们提供了以下诊断信息
ERROR: Graphs differed across invocations!
Graph diff:
graph(%x : Tensor) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Tensor = aten::select(%x, %4, %5)
%result.2 : Tensor = aten::mul(%result.1, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Tensor = aten::select(%x, %8, %9)
- %result : Tensor = aten::mul(%result.2, %10)
+ %result.3 : Tensor = aten::mul(%result.2, %10)
? ++
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Tensor = aten::select(%x, %12, %13)
+ %result : Tensor = aten::mul(%result.3, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Tensor = aten::select(%x, %16, %17)
- %15 : Tensor = aten::mul(%result, %14)
? ^ ^
+ %19 : Tensor = aten::mul(%result, %18)
? ^ ^
- return (%15);
? ^
+ return (%19);
? ^
}
此消息向我们表明,计算结果在我们首次跟踪时与使用 check_inputs
跟踪时不同。确实,loop_in_traced_fn
函数体内的循环取决于输入 x
的形状,因此当我们尝试使用不同形状的另一个 x
时,跟踪结果会不同。
在这种情况下,像这样的数据依赖控制流可以使用 torch.jit.script()
来捕获
def fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result
inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]
scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:
torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple))
这将产生
graph(%x : Tensor) {
%5 : bool = prim::Constant[value=1]()
%1 : int = prim::Constant[value=0]()
%result.1 : Tensor = aten::select(%x, %1, %1)
%4 : int = aten::size(%x, %1)
%result : Tensor = prim::Loop(%4, %5, %result.1)
block0(%i : int, %7 : Tensor) {
%10 : Tensor = aten::select(%x, %1, %i)
%result.2 : Tensor = aten::mul(%7, %10)
-> (%5, %result.2)
}
return (%result);
}
跟踪器警告¶
跟踪器会针对跟踪计算中的几个有问题模式产生警告。例如,跟踪一个对 Tensor 的切片(视图)进行就地赋值的函数
def fill_row_zero(x):
x[0] = torch.rand(*x.shape[1:2])
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
产生几个警告和一个简单返回输入的图
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
return (%0);
}
我们可以通过修改代码来解决这个问题,不使用就地更新,而是使用 torch.cat
以非就地方式构建结果张量。
def fill_row_zero(x):
x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
return x
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
常见问题¶
问:我想在 GPU 上训练模型并在 CPU 上进行推理。最佳实践是什么?
首先将您的模型从 GPU 转换为 CPU,然后保存它,如下所示
cpu_model = gpu_model.cpu() sample_input_cpu = sample_input_gpu.cpu() traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) torch.jit.save(traced_cpu, "cpu.pt") traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) torch.jit.save(traced_gpu, "gpu.pt") # ... later, when using the model: if use_gpu: model = torch.jit.load("gpu.pt") else: model = torch.jit.load("cpu.pt") model(input)推荐这样做是因为跟踪器可能会在特定设备上看到张量创建,因此对已加载的模型进行类型转换可能会产生意外影响。在保存模型之前对其进行类型转换可确保跟踪器具有正确的设备信息。
问:如何在 ScriptModule
上存储属性?
假设我们有以下模型
import torch class Model(torch.nn.Module): def __init__(self): super().__init__() self.x = 2 def forward(self): return self.x m = torch.jit.script(Model())如果实例化
Model
,将导致编译错误,因为编译器不知道x
。有 4 种方法可以告知编译器关于ScriptModule
上的属性1.
nn.Parameter
- 包装在nn.Parameter
中的值将像在nn.Module
中一样工作2.
register_buffer
- 包装在register_buffer
中的值将像在nn.Module
中一样工作。这等效于一个类型为Tensor
的属性(参见 4)。3. 常量 - 将类成员注释为
Final
(或在类定义级别将其添加到名为__constants__
的列表中)将把包含的名称标记为常量。常量直接保存在模型的代码中。有关详细信息,请参阅 内置常量。4. 属性 - 支持类型 的值可以作为可变属性添加。大多数类型可以被推断,但有些可能需要指定,有关详细信息,请参阅 模块属性。
问:我想跟踪模块的方法,但我一直收到此错误
RuntimeError: 无法插入需要梯度的 Tensor 作为常量。考虑将其设为参数或输入,或分离梯度
此错误通常意味着您正在跟踪的方法使用了模块的参数,并且您传递的是模块的方法,而不是模块实例(例如
my_module_instance.forward
对比my_module_instance
)。
使用模块的方法调用
trace
会将模块参数(可能需要梯度)作为 常量 捕获。另一方面,使用模块实例(例如
my_module
)调用trace
会创建一个新模块并将参数正确复制到新模块中,这样如果需要,它们就可以累积梯度。要跟踪模块上的特定方法,请参阅
torch.jit.trace_module
已知问题¶
如果您在 TorchScript 中使用 Sequential
,某些 Sequential
子模块的输入可能会被错误地推断为 Tensor
,即使它们被以其他方式注释。标准解决方案是继承 nn.Sequential
并正确声明输入类型来重新声明 forward
。
附录¶
迁移到 PyTorch 1.2 递归脚本 API¶
本节详细介绍了 PyTorch 1.2 中 TorchScript 的变化。如果您是 TorchScript 的新用户,可以跳过本节。PyTorch 1.2 对 TorchScript API 有两个主要变化。
1. torch.jit.script
现在将尝试递归编译遇到的函数、方法和类。一旦调用 torch.jit.script
,编译就是“选择退出”,而不是“选择加入”。
2. torch.jit.script(nn_module_instance)
现在是创建 ScriptModule
的首选方式,而不是继承自 torch.jit.ScriptModule
。这些变化结合起来,提供了一个更简单、更易于使用的 API,用于将您的 nn.Module
转换为 ScriptModule
,以便在非 Python 环境中进行优化和执行。
新用法如下所示
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
my_model = Model()
my_scripted_model = torch.jit.script(my_model)
模块的
forward
默认会被编译。从forward
中调用的方法会根据其在forward
中的使用顺序被惰性编译。要编译除了
forward
之外且未从forward
调用的方法,请添加@torch.jit.export
。要阻止编译器编译某个方法,请添加
@torch.jit.ignore
或@torch.jit.unused
。@ignore
会将该方法保留为对 python 的调用,而
@unused
则将其替换为异常。@ignored
无法导出;@unused
可以。大多数属性类型可以被推断,因此
torch.jit.Attribute
不是必需的。对于空容器类型,请使用 PEP 526 风格的类注解来注解其类型。常量可以使用
Final
类注解进行标记,而不是将成员名称添加到__constants__
。可以使用 Python 3 类型提示来代替
torch.jit.annotate
- 由于这些变化,以下项被视为已弃用,不应出现在新代码中
继承自
torch.jit.ScriptModule
的类The
torch.jit.Attribute
wrapper classThe
__constants__
arrayThe
torch.jit.annotate
function
模块¶
警告
@torch.jit.ignore
注解的行为在 PyTorch 1.2 中发生了变化。在 PyTorch 1.2 之前,@ignore 装饰器用于使一个函数或方法可以从导出的代码中调用。要重新获得此功能,请使用 @torch.jit.unused()
。@torch.jit.ignore
现在等同于 @torch.jit.ignore(drop=False)
。详情请参阅 @torch.jit.ignore
和 @torch.jit.unused
。
当传递给 torch.jit.script
函数时,torch.nn.Module
的数据会被复制到一个 ScriptModule
,并且 TorchScript 编译器会编译该模块。模块的 forward
方法默认会被编译。从 forward
中调用的方法会按照它们在 forward
中被使用的顺序进行延迟编译,任何带有 @torch.jit.export
注解的方法也会被编译。
- torch.jit.export(fn)[source][source]¶
此装饰器指示
nn.Module
上的一个方法用作ScriptModule
的入口点,并且应该被编译。forward
隐式地被假定为入口点,因此它不需要此装饰器。从forward
中调用的函数和方法会在编译器看到它们时进行编译,因此它们也不需要此装饰器。示例(在方法上使用
@torch.jit.export
)import torch import torch.nn as nn class MyModule(nn.Module): def implicitly_compiled_method(self, x): return x + 99 # `forward` is implicitly decorated with `@torch.jit.export`, # so adding it here would have no effect def forward(self, x): return x + 10 @torch.jit.export def another_forward(self, x): # When the compiler sees this call, it will compile # `implicitly_compiled_method` return self.implicitly_compiled_method(x) def unused_method(self, x): return x - 20 # `m` will contain compiled methods: # `forward` # `another_forward` # `implicitly_compiled_method` # `unused_method` will not be compiled since it was not called from # any compiled methods and wasn't decorated with `@torch.jit.export` m = torch.jit.script(MyModule())
函数¶
函数变化不大,如果需要,可以使用 @torch.jit.ignore
或 torch.jit.unused
进行装饰。
# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
return 2
# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
return 2
# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
import pdb; pdb.set_trace()
return 4
# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
return 2
TorchScript 类¶
警告
TorchScript 类的支持是实验性的。目前它最适合简单的记录类型(可以认为是附带方法的 NamedTuple
)。
用户自定义 TorchScript 类中的所有内容默认都会被导出,如果需要,函数可以使用 @torch.jit.ignore
进行装饰。
属性¶
TorchScript 编译器需要知道 模块属性 的类型。大多数类型可以从成员的值推断出来。空列表和字典无法推断其类型,必须使用 PEP 526 风格 的类注解来指定其类型。如果无法推断类型且未显式注解,它将不会作为属性添加到生成的 ScriptModule
中。
旧 API
from typing import Dict
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.my_dict = torch.jit.Attribute({}, Dict[str, int])
self.my_int = torch.jit.Attribute(20, int)
m = MyModule()
新 API
from typing import Dict
class MyModule(torch.nn.Module):
my_dict: Dict[str, int]
def __init__(self):
super().__init__()
# This type cannot be inferred and must be specified
self.my_dict = {}
# The attribute type here is inferred to be `int`
self.my_int = 20
def forward(self):
pass
m = torch.jit.script(MyModule())
常量¶
Final
类型构造函数可用于将成员标记为 常量。如果成员未标记为常量,它们将被复制到生成的 ScriptModule
中作为属性。使用 Final
可以为已知固定值的成员提供优化机会,并增加额外的类型安全性。
旧 API
class MyModule(torch.jit.ScriptModule):
__constants__ = ['my_constant']
def __init__(self):
super().__init__()
self.my_constant = 2
def forward(self):
pass
m = MyModule()
新 API
from typing import Final
class MyModule(torch.nn.Module):
my_constant: Final[int]
def __init__(self):
super().__init__()
self.my_constant = 2
def forward(self):
pass
m = torch.jit.script(MyModule())
变量¶
容器被假定为 Tensor
类型且非可选(更多信息请参阅 默认类型)。之前,torch.jit.annotate
用于告知 TorchScript 编译器类型应该是什么。现在支持 Python 3 风格的类型提示。
import torch
from typing import Dict, Optional
@torch.jit.script
def make_dict(flag: bool):
x: Dict[str, int] = {}
x['hi'] = 2
b: Optional[int] = None
if flag:
b = 2
return x, b
融合后端¶
有几种融合后端可用于优化 TorchScript 执行。在 CPU 上默认的 fuser 是 NNC,它可以为 CPU 和 GPU 执行融合。在 GPU 上默认的 fuser 是 NVFuser,它支持更广泛的运算符,并已证明生成的内核具有更高的吞吐量。有关使用和调试的更多详细信息,请参阅 NVFuser 文档。