TorchScript 语言参考¶
TorchScript 是 Python 的静态类型子集,可以直接编写(使用 @torch.jit.script
装饰器),也可以通过追踪从 Python 代码自动生成。当使用追踪时,代码会被自动转换为 Python 的这个子集,方法是只记录张量上的实际操作符,并简单地执行和丢弃其他周围的 Python 代码。
当直接使用 @torch.jit.script
装饰器编写 TorchScript 时,程序员必须仅使用 TorchScript 中支持的 Python 子集。本节记录了 TorchScript 中支持的内容,如同它是一个独立语言的语言参考。本参考中未提及的任何 Python 功能都不是 TorchScript 的一部分。有关可用的 PyTorch 张量方法、模块和函数的完整参考,请参阅内置函数。
作为 Python 的子集,任何有效的 TorchScript 函数也是一个有效的 Python 函数。这使得可以禁用 TorchScript 并使用标准的 Python 工具(如 pdb
)调试函数。反之则不然:有很多有效的 Python 程序不是有效的 TorchScript 程序。相反,TorchScript 专门关注 Python 中表示 PyTorch 中的神经网络模型所需的功能。
类型¶
TorchScript 和完整 Python 语言之间最大的区别在于,TorchScript 仅支持表达神经网络模型所需的一小部分类型。特别是,TorchScript 支持
类型 |
描述 |
---|---|
|
任何 dtype、维度或后端的 PyTorch 张量 |
|
包含子类型 |
|
一个布尔值 |
|
一个标量整数 |
|
一个标量浮点数 |
|
一个字符串 |
|
所有成员均为 |
|
一个可以是 None 或 T 类型的值 |
|
一个键类型为 |
|
|
|
|
|
一个 |
|
子类型 |
与 Python 不同,TorchScript 函数中的每个变量都必须具有单一的静态类型。这使得优化 TorchScript 函数更容易。
示例(类型不匹配)
import torch
@torch.jit.script
def an_error(x):
if x:
r = torch.rand(1)
else:
r = 4
return r
Traceback (most recent call last):
...
RuntimeError: ...
Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
if x:
~~~~~
r = torch.rand(1)
~~~~~~~~~~~~~~~~~
else:
~~~~~
r = 4
~~~~~ <--- HERE
return r
and was used here:
else:
r = 4
return r
~ <--- HERE...
不支持的类型构造¶
TorchScript 不支持 typing
模块的所有特性和类型。其中一些是更基本的东西,不太可能在未来添加,而另一些可能会在用户需求足够使其成为优先事项时添加。
来自 typing
模块的以下类型和功能在 TorchScript 中不可用。
项 |
描述 |
---|---|
|
|
未实现 |
|
未实现 |
|
未实现 |
|
未实现 |
|
未实现 |
|
这支持模块属性类属性注释,但不支持函数 |
|
TorchScript 不支持 |
|
|
|
类型别名 |
未实现 |
标称子类型 vs 结构子类型 |
标称类型正在开发中,但结构类型不是 |
NewType |
不太可能实现 |
泛型 |
不太可能实现 |
typing 模块中未在本文档中明确列出的任何其他功能均不受支持。
默认类型¶
默认情况下,TorchScript 函数的所有参数都假定为 Tensor。要指定 TorchScript 函数的参数是另一种类型,可以使用 MyPy 风格的类型注解,使用上面列出的类型。
import torch
@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
注意
也可以使用 typing 模块中的 Python 3 类型提示来注解类型。
import torch
from typing import Tuple
@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
空列表假定为 List[Tensor]
,空字典假定为 Dict[str, Tensor]
。要实例化其他类型的空列表或字典,请使用Python 3 类型提示。
示例(Python 3 的类型注解)
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class EmptyDataStructures(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
# This annotates the list to be a `List[Tuple[int, float]]`
my_list: List[Tuple[int, float]] = []
for i in range(10):
my_list.append((i, x.item()))
my_dict: Dict[str, int] = {}
return my_list, my_dict
x = torch.jit.script(EmptyDataStructures())
可选类型精细化¶
当在 if 语句的条件内部进行与 None
的比较或在 assert
中检查时,TorchScript 将精细化 Optional[T]
类型变量的类型。编译器可以推理使用 and
、or
和 not
组合的多个 None
检查。对于未显式编写的 if 语句的 else 块,也会发生精细化。
None
检查必须在 if 语句的条件内部;将 None
检查分配给变量并在 if 语句的条件中使用它不会精细化检查中变量的类型。只会精细化局部变量,像 self.x
这样的属性不会,并且必须分配给局部变量才能精细化。
示例(精细化参数和局部变量的类型)
import torch
import torch.nn as nn
from typing import Optional
class M(nn.Module):
z: Optional[int]
def __init__(self, z):
super().__init__()
# If `z` is None, its type cannot be inferred, so it must
# be specified (above)
self.z = z
def forward(self, x, y, z):
# type: (Optional[int], Optional[int], Optional[int]) -> int
if x is None:
x = 1
x = x + 1
# Refinement for an attribute by assigning it to a local
z = self.z
if y is not None and z is not None:
x = y + z
# Refinement via an `assert`
assert z is not None
x += z
return x
module = torch.jit.script(M(2))
module = torch.jit.script(M(None))
TorchScript 类¶
警告
TorchScript 类支持是实验性的。目前,它最适合简单的记录型类型(想想附加了方法的 NamedTuple
)。
如果 Python 类使用 @torch.jit.script
注解,则可以在 TorchScript 中使用,类似于声明 TorchScript 函数的方式
@torch.jit.script
class Foo:
def __init__(self, x, y):
self.x = x
def aug_add_x(self, inc):
self.x += inc
此子集受到限制
所有函数都必须是有效的 TorchScript 函数(包括
__init__()
)。类必须是新式类,因为我们使用
__new__()
和 pybind11 来构造它们。TorchScript 类是静态类型的。成员只能通过在
__init__()
方法中赋值给 self 来声明。例如,在
__init__()
方法外部赋值给self
@torch.jit.script class Foo: def assign_x(self): self.x = torch.rand(2, 3)
将导致
RuntimeError: Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: def assign_x(self): self.x = torch.rand(2, 3) ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
除了方法定义之外,类的主体中不允许任何表达式。
不支持继承或任何其他多态策略,除非从
object
继承以指定新式类。
定义类后,它可以像任何其他 TorchScript 类型一样在 TorchScript 和 Python 中互换使用
# Declare a TorchScript class
@torch.jit.script
class Pair:
def __init__(self, first, second):
self.first = first
self.second = second
@torch.jit.script
def sum_pair(p):
# type: (Pair) -> Tensor
return p.first + p.second
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
TorchScript 枚举¶
Python 枚举可以在 TorchScript 中使用,无需任何额外的注解或代码
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
定义枚举后,它可以像任何其他 TorchScript 类型一样在 TorchScript 和 Python 中互换使用。枚举值的类型必须是 int
、float
或 str
。所有值必须是相同的类型;不支持枚举值的异构类型。
命名元组¶
collections.namedtuple
生成的类型可以在 TorchScript 中使用。
import torch
import collections
Point = collections.namedtuple('Point', ['x', 'y'])
@torch.jit.script
def total(point):
# type: (Point) -> Tensor
return point.x + point.y
p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))
可迭代对象¶
某些函数(例如,zip
和 enumerate
)只能对可迭代类型进行操作。TorchScript 中的可迭代类型包括 Tensor
s、列表、元组、字典、字符串、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
表达式¶
支持以下 Python 表达式。
字面量¶
True
False
None
'string literals'
"string literals"
3 # interpreted as int
3.4 # interpreted as a float
列表构造¶
空列表假定为 List[Tensor]
类型。其他列表字面量的类型从成员的类型派生。有关更多详细信息,请参阅默认类型。
[3, 4]
[]
[torch.rand(3), torch.rand(4)]
元组构造¶
(3, 4)
(3,)
算术运算符¶
a + b
a - b
a * b
a / b
a ^ b
a @ b
比较运算符¶
a == b
a != b
a < b
a > b
a <= b
a >= b
逻辑运算符¶
a and b
a or b
not b
下标和切片¶
t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]
函数调用¶
调用内置函数
torch.rand(3, dtype=torch.int)
调用其他脚本函数
import torch
@torch.jit.script
def foo(x):
return x + 1
@torch.jit.script
def bar(x):
return foo(x)
方法调用¶
调用内置类型的方法,如张量:x.mm(y)
在模块上,方法必须先编译才能被调用。TorchScript 编译器在编译其他方法时会递归编译它看到的方法。默认情况下,编译从 forward
方法开始。forward
调用的任何方法都将被编译,以及这些方法调用的任何方法,依此类推。要在 forward
之外的方法开始编译,请使用 @torch.jit.export
装饰器(forward
隐式标记为 @torch.jit.export
)。
直接调用子模块(例如 self.resnet(input)
)等同于调用其 forward
方法(例如 self.resnet.forward(input)
)。
import torch
import torch.nn as nn
import torchvision
class MyModule(nn.Module):
def __init__(self):
super().__init__()
means = torch.tensor([103.939, 116.779, 123.68])
self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
resnet = torchvision.models.resnet18()
self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))
def helper(self, input):
return self.resnet(input - self.means)
def forward(self, input):
return self.helper(input)
# Since nothing in the model calls `top_level_method`, the compiler
# must be explicitly told to compile this method
@torch.jit.export
def top_level_method(self, input):
return self.other_helper(input)
def other_helper(self, input):
return input + 10
# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())
三元表达式¶
x if x > y else y
类型转换¶
float(ten)
int(3.5)
bool(ten)
str(2)``
访问模块参数¶
self.my_parameter
self.my_submodule.my_parameter
语句¶
TorchScript 支持以下类型的语句
简单赋值¶
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b
打印语句¶
print("the result of an add:", a + b)
If 语句¶
if a < 4:
r = -a
elif a < 3:
r = a + a
else:
r = 3 * a
除了布尔值之外,浮点数、整数和张量也可以在条件中使用,并将被隐式转换为布尔值。
While 循环¶
a = 0
while a < 4:
print(a)
a += 1
带有 range 的 For 循环¶
x = 0
for i in range(10):
x *= i
遍历元组的 For 循环¶
这些循环展开,为元组的每个成员生成一个主体。主体必须为每个成员正确进行类型检查。
tup = (3, torch.rand(4))
for x in tup:
print(x)
遍历常量 nn.ModuleList 的 For 循环¶
要在编译方法内部使用 nn.ModuleList
,必须通过将属性名称添加到类型的 __constants__
列表中来将其标记为常量。遍历 nn.ModuleList
的 For 循环将在编译时展开循环体,其中包含常量模块列表的每个成员。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
__constants__ = ['mods']
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
m = torch.jit.script(MyModule())
Break 和 Continue¶
for i in range(5):
if i == 1:
continue
if i == 3:
break
print(i)
Return¶
return a, b
变量解析¶
TorchScript 支持 Python 变量解析(即作用域)规则的子集。局部变量的行为与 Python 中相同,但有一个限制,即变量在函数的所有路径中必须具有相同的类型。如果变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束后使用它会出错。
同样,如果变量仅在函数的部分路径中<莲花>定义,则不允许使用该变量。
示例
@torch.jit.script
def foo(x):
if x < 0:
y = 4
print(y)
Traceback (most recent call last):
...
RuntimeError: ...
y is not defined in the false branch...
@torch.jit.script...
def foo(x):
if x < 0:
~~~~~~~~~
y = 4
~~~~~ <--- HERE
print(y)
and was used here:
if x < 0:
y = 4
print(y)
~ <--- HERE...
非局部变量在定义函数时在编译时解析为 Python 值。然后使用 Python 值的用法中描述的规则将这些值转换为 TorchScript 值。
Python 值的用法¶
为了使编写 TorchScript 更加方便,我们允许脚本代码引用周围作用域中的 Python 值。例如,每当引用 torch
时,TorchScript 编译器实际上会在声明函数时将其解析为 torch
Python 模块。这些 Python 值不是 TorchScript 的第一类部分。相反,它们在编译时被反糖化为 TorchScript 支持的原始类型。这取决于编译发生时引用的 Python 值的动态类型。本节描述了在 TorchScript 中访问 Python 值时使用的规则。
函数¶
TorchScript 可以调用 Python 函数。当逐步将模型转换为 TorchScript 时,此功能非常有用。模型可以逐函数移动到 TorchScript,同时保留对 Python 函数的调用。这样,您可以逐步检查模型的正确性。
- torch.jit.is_scripting()[source][source]¶
当处于编译中时返回 True,否则返回 False 的函数。这在使用 @unused 装饰器来保留模型中尚未与 TorchScript 兼容的代码时特别有用。 .. testcode
import torch @torch.jit.unused def unsupported_linear_op(x): return x def linear(x): if torch.jit.is_scripting(): return torch.linear(x) else: return unsupported_linear_op(x)
- 返回类型
Python 模块上的属性查找¶
TorchScript 可以在模块上查找属性。内置函数(如 torch.add
)就是通过这种方式访问的。这允许 TorchScript 调用在其他模块中定义的函数。
Python 定义的常量¶
TorchScript 还提供了一种使用在 Python 中定义的常量的方法。这些常量可用于将超参数硬编码到函数中,或定义通用常量。有两种方法可以指定应将 Python 值视为常量。
作为模块属性查找的值被假定为常量
import math
import torch
@torch.jit.script
def fn():
return math.pi
ScriptModule 的属性可以通过使用
Final[T]
注解来标记为常量
import torch
import torch.nn as nn
class Foo(nn.Module):
# `Final` from the `typing_extensions` module can also be used
a : torch.jit.Final[int]
def __init__(self):
super().__init__()
self.a = 1 + 4
def forward(self, input):
return self.a + input
f = torch.jit.script(Foo())
支持的常量 Python 类型有
整数
浮点数
布尔值
torch.device
torch.layout
torch.dtype
包含支持类型的元组
torch.nn.ModuleList
,可以在 TorchScript for 循环中使用
模块属性¶
torch.nn.Parameter
包装器和 register_buffer
可用于将张量分配给模块。分配给已编译模块的其他值将在其类型可以推断的情况下添加到已编译模块。TorchScript 中可用的所有类型都可以用作模块属性。张量属性在语义上与缓冲区相同。空列表、字典和 None
值的类型无法推断,必须通过 PEP 526 风格的类注解来指定。如果类型无法推断且未显式注解,则不会将其作为属性添加到生成的 ScriptModule
中。
示例
from typing import List, Dict
class Foo(nn.Module):
# `words` is initialized as an empty list, so its type must be specified
words: List[str]
# The type could potentially be inferred if `a_dict` (below) was not
# empty, but this annotation ensures `some_dict` will be made into the
# proper type
some_dict: Dict[str, int]
def __init__(self, a_dict):
super().__init__()
self.words = []
self.some_dict = a_dict
# `int`s can be inferred
self.my_int = 10
def forward(self, input):
# type: (str) -> int
self.words.append(input)
return self.some_dict[input] + self.my_int
f = torch.jit.script(Foo({'hi': 2}))