TorchScript 语言参考¶
本参考手册介绍了 TorchScript 语言的语法和核心语义。TorchScript 是 Python 语言的静态类型子集。本文档解释了 TorchScript 中支持的 Python 功能,以及该语言与常规 Python 的不同之处。本参考手册中未提及的任何 Python 功能都不属于 TorchScript。TorchScript 专注于 Python 中用于表示 PyTorch 中神经网络模型的功能。
术语¶
本文档使用以下术语
模式 |
备注 |
---|---|
|
表示给定符号被定义为。 |
|
表示语法中包含的实际关键字和分隔符。 |
|
表示 A 或 B。 |
|
表示分组。 |
|
表示可选。 |
|
表示正则表达式,其中术语 A 至少重复一次。 |
|
表示正则表达式,其中术语 A 重复零次或多次。 |
类型系统¶
TorchScript 是 Python 的静态类型子集。TorchScript 与完整 Python 语言之间最大的区别在于,TorchScript 只支持一小部分用于表达神经网络模型的类型。
TorchScript 类型¶
TorchScript 类型系统由 TSType
和 TSModuleType
组成,定义如下。
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSType
表示大多数 TorchScript 类型,这些类型是可组合的,并且可以在 TorchScript 类型注解中使用。 TSType
指代以下任何一项
元类型,例如
Any
原始类型,例如
int
、float
和str
结构类型,例如
Optional[int]
或List[MyClass]
名义类型(Python 类),例如
MyClass
(用户定义)和torch.tensor
(内置)
TSModuleType
表示 torch.nn.Module
及其子类。它与 TSType
的处理方式不同,因为它 的类型模式部分是从对象实例推断的,部分是从类定义推断的。因此,TSModuleType
的实例可能不遵循相同的静态类型模式。出于类型安全的考虑,TSModuleType
不能用作 TorchScript 类型注解,也不能与 TSType
组合。
元类型¶
元类型非常抽象,更像是类型约束,而不是具体类型。目前 TorchScript 定义了一个元类型,Any
,它表示任何 TorchScript 类型。
Any
类型¶
Any
类型表示任何 TorchScript 类型。 Any
不指定任何类型约束,因此不会对 Any
进行类型检查。因此,它可以绑定到任何 Python 或 TorchScript 数据类型(例如,int
、TorchScript tuple
或任意未脚本化的 Python 类)。
TSMetaType ::= "Any"
在哪里
Any
是来自 typing 模块的 Python 类名。因此,要使用Any
类型,您必须从typing
中导入它(例如,from typing import Any
)。由于
Any
可以表示任何 TorchScript 类型,因此允许在Any
上对该类型的值进行操作的操作符集是有限的。
支持 Any
类型的操作符¶
将数据分配给
Any
类型的变量。绑定到
Any
类型的参数或返回值。x is
,x is not
,其中x
为Any
类型。isinstance(x, Type)
,其中x
为Any
类型。Any
类型的可打印。List[Any]
类型的可排序,如果数据是相同类型T
的值的列表,并且T
支持比较运算符。
与 Python 相比
Any
是 TorchScript 类型系统中最不严格的类型。从这个意义上说,它与 Python 中的 Object
类非常相似。但是,Any
只支持 Object
所支持的操作符和方法的一个子集。
设计说明¶
当我们对 PyTorch 模块进行脚本化时,可能会遇到与脚本执行无关的数据。然而,它必须由类型模式描述。不仅为未使用的(在脚本上下文中)数据描述静态类型很麻烦,而且还可能导致不必要的脚本化失败。Any
被用来描述在编译时不需要精确静态类型的那些数据的类型。
示例 1
此示例说明了如何使用 Any
允许元组参数的第二个元素为任何类型。这是可能的,因为 x[1]
不涉及任何需要知道其精确类型的计算。
import torch
from typing import Tuple
from typing import Any
@torch.jit.export
def inc_first_element(x: Tuple[int, Any]):
return (x[0]+1, x[1])
m = torch.jit.script(inc_first_element)
print(m((1,2.0)))
print(m((1,(100,200))))
上面的示例产生以下输出
(2, 2.0)
(2, (100, 200))
元组的第二个元素是 Any
类型,因此可以绑定到多种类型。例如,(1, 2.0)
将 float 类型绑定到 Any
,如 Tuple[int, Any]
中,而 (1, (100, 200))
在第二次调用中将元组绑定到 Any
。
示例 2
此示例说明了如何使用 isinstance
来动态检查注释为 Any
类型的数据的类型
import torch
from typing import Any
def f(a:Any):
print(a)
return (isinstance(a, torch.Tensor))
ones = torch.ones([2])
m = torch.jit.script(f)
print(m(ones))
上面的示例产生以下输出
1
1
[ CPUFloatType{2} ]
True
原始类型¶
原始 TorchScript 类型是指表示单个值类型的类型,并带有单个预定义类型名称。
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
结构类型¶
结构类型是指在结构上定义的类型,没有用户定义的名称(与名义类型不同),例如 Future[int]
。结构类型可以用任何 TSType
组合。
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict |
TSOptional | TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
在哪里
Tuple
、List
、Optional
、Union
、Future
、Dict
代表在模块typing
中定义的 Python 类型类名称。要使用这些类型名称,必须从typing
中导入它们(例如,from typing import Tuple
)。namedtuple
代表 Python 类collections.namedtuple
或typing.NamedTuple
。Future
和RRef
代表 Python 类torch.futures
和torch.distributed.rpc
。Await
代表 Python 类torch._awaits._Await
与 Python 相比
除了与 TorchScript 类型组合之外,这些 TorchScript 结构类型通常还支持其 Python 对应物的操作符和方法的常用子集。
示例 1
此示例使用 typing.NamedTuple
语法定义元组
import torch
from typing import NamedTuple
from typing import Tuple
class MyTuple(NamedTuple):
first: int
second: int
def inc(x: MyTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
t = MyTuple(first=1, second=2)
scripted_inc = torch.jit.script(inc)
print("TorchScript:", scripted_inc(t))
上面的示例产生以下输出
TorchScript: (2, 3)
示例 2
此示例使用 collections.namedtuple
语法定义元组
import torch
from typing import NamedTuple
from typing import Tuple
from collections import namedtuple
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second'])
def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
m = torch.jit.script(inc)
print(inc(_UnannotatedNamedTuple(1,2)))
上面的示例产生以下输出
(2, 3)
示例 3
此示例说明了注释结构类型的常见错误,即没有从 typing
模块导入组合类型类
import torch
# ERROR: Tuple not recognized because not imported from typing
@torch.jit.export
def inc(x: Tuple[int, int]):
return (x[0]+1, x[1]+1)
m = torch.jit.script(inc)
print(m((1,2)))
运行上述代码会产生以下脚本化错误
File "test-tuple.py", line 5, in <module>
def inc(x: Tuple[int, int]):
NameError: name 'Tuple' is not defined
解决方法是在代码开头添加 from typing import Tuple
一行。
名义类型¶
名义 TorchScript 类型是指 Python 类。这些类型被称为名义类型,因为它们是用自定义名称声明的,并且使用类名进行比较。名义类进一步分为以下类别
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
其中,TSCustomClass
和 TSEnum
必须可编译为 TorchScript 中间表示(IR)。这是由类型检查器强制执行的。
内置类¶
内置名义类型是指其语义内置于 TorchScript 系统中的 Python 类(例如,张量类型)。TorchScript 定义了这些内置名义类型的语义,并且通常只支持其 Python 类定义的方法或属性的一个子集。
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
"torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
"torch.nn.parameter.Parameter" | and subclasses of torch.Tensor
关于 torch.nn.ModuleList 和 torch.nn.ModuleDict 的特别说明¶
虽然 torch.nn.ModuleList
和 torch.nn.ModuleDict
在 Python 中被定义为列表和字典,但在 TorchScript 中更像是元组
在 TorchScript 中,
torch.nn.ModuleList
或torch.nn.ModuleDict
的实例是不可变的。遍历
torch.nn.ModuleList
或torch.nn.ModuleDict
的代码是完全展开的,因此torch.nn.ModuleList
的元素或torch.nn.ModuleDict
的键可以是torch.nn.Module
的不同子类。
示例
以下示例突出显示了一些内置 Torchscript 类(torch.*
)的使用
import torch
@torch.jit.script
class A:
def __init__(self):
self.x = torch.rand(3)
def f(self, y: torch.device):
return self.x.to(device=y)
def g():
a = A()
return a.f(torch.device("cpu"))
script_g = torch.jit.script(g)
print(script_g.graph)
自定义类¶
与内置类不同,自定义类的语义是用户定义的,整个类定义必须可编译为 TorchScript IR 并且符合 TorchScript 类型检查规则。
TSClassDef ::= [ "@torch.jit.script" ]
"class" ClassName [ "(object)" ] ":"
MethodDefinition |
[ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
MethodDefinition
在哪里
类必须是新式类。Python 3 只支持新式类。在 Python 2.x 中,新式类是通过继承自对象来指定的。
实例数据属性是静态类型的,并且实例属性必须在
__init__()
方法中通过赋值来声明。不支持方法重载(即,不能有多个具有相同方法名的方法)。
MethodDefinition
必须可编译为 TorchScript IR 并且符合 TorchScript 的类型检查规则(即,所有方法都必须是有效的 TorchScript 函数,并且类属性定义必须是有效的 TorchScript 语句)。torch.jit.ignore
和torch.jit.unused
可用于忽略不能完全 torchscriptable 或应被编译器忽略的方法或函数。
与 Python 相比
与它们的 Python 对应物相比,TorchScript 自定义类非常有限。Torchscript 自定义类
不支持类属性。
不支持子类化,除了子类化接口类型或对象。
不支持方法重载。
必须在
__init__()
中初始化其所有实例属性;这是因为 TorchScript 通过在__init__()
中推断属性类型来构建类的静态模式。必须只包含满足 TorchScript 类型检查规则并可编译为 TorchScript IR 的方法。
示例 1
如果 Python 类用 @torch.jit.script
进行了注释,则可以在 TorchScript 中使用它们,类似于声明 TorchScript 函数的方式
@torch.jit.script
class MyClass:
def __init__(self, x: int):
self.x = x
def inc(self, val: int):
self.x += val
示例 2
TorchScript 自定义类类型必须通过在 __init__()
中赋值来“声明”其所有实例属性。如果在 __init__()
中没有定义实例属性,但在类的其他方法中被访问,则该类不能被编译为 TorchScript 类,如下例所示
import torch
@torch.jit.script
class foo:
def __init__(self):
self.y = 1
# ERROR: self.x is not defined in __init__
def assign_x(self):
self.x = torch.rand(2, 3)
该类将无法编译并发出以下错误
RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
self.x = torch.rand(2, 3)
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
示例 3
在此示例中,TorchScript 自定义类定义了一个类变量名,这是不允许的
import torch
@torch.jit.script
class MyClass(object):
name = "MyClass"
def __init__(self, x: int):
self.x = x
def fn(a: MyClass):
return a.name
它会导致以下编译时错误
RuntimeError:
'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
File "test-class2.py", line 10
def fn(a: MyClass):
return a.name
~~~~~~ <--- HERE
枚举类型¶
与自定义类一样,枚举类型的语义是用户定义的,整个类定义必须可编译为 TorchScript IR 并且符合 TorchScript 类型检查规则。
TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":"
( MemberIdentifier "=" Value )+
( MethodDefinition )*
在哪里
值必须是
int
、float
或str
类型的 TorchScript 字面量,并且必须是相同 TorchScript 类型。TSEnumType
是 TorchScript 枚举类型的名称。与 Python 枚举类似,TorchScript 允许有限的Enum
子类化,即,只有在没有定义任何成员的情况下,才允许子类化枚举。
与 Python 相比
TorchScript 仅支持
enum.Enum
。 它不支持其他变体,例如enum.IntEnum
,enum.Flag
,enum.IntFlag
和enum.auto
。TorchScript 枚举成员的值必须是相同类型,并且只能是
int
,float
或str
类型,而 Python 枚举成员可以是任何类型。包含方法的枚举在 TorchScript 中会被忽略。
示例 1
以下示例定义了类 Color
作为 Enum
类型。
import torch
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
print("TorchScript: ", m(Color.RED, Color.GREEN))
示例 2
以下示例展示了受限枚举子类化的案例,其中 BaseColor
没有定义任何成员,因此可以被 Color
子类化。
import torch
from enum import Enum
class BaseColor(Enum):
def foo(self):
pass
class Color(BaseColor):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("TorchScript: ", m(Color.RED, Color.GREEN))
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
TorchScript 模块类¶
TSModuleType
是一种特殊的类类型,它从在 TorchScript 外部创建的对象实例中推断出来。 TSModuleType
由对象实例的 Python 类命名。 Python 类的 __init__()
方法不被视为 TorchScript 方法,因此它不需要符合 TorchScript 的类型检查规则。
模块实例类的类型模式直接从实例对象(在 TorchScript 范围之外创建)构建,而不是像自定义类那样从 __init__()
中推断。 两个相同实例类类型的对象可能遵循两种不同的类型模式。
从这个意义上讲,TSModuleType
并不是真正的静态类型。 因此,出于类型安全考虑,TSModuleType
不能用于 TorchScript 类型注释,也不能与 TSType
组合。
模块实例类¶
TorchScript 模块类型表示用户定义的 PyTorch 模块实例的类型模式。 在对 PyTorch 模块进行脚本化时,模块对象始终在 TorchScript 外部创建(即作为参数传递给 forward
)。 Python 模块类被视为模块实例类,因此 Python 模块类的 __init__()
方法不受 TorchScript 的类型检查规则约束。
TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":"
ClassBodyDefinition
在哪里
forward()
和其他使用@torch.jit.export
装饰的方法必须可以编译为 TorchScript IR 并且符合 TorchScript 的类型检查规则。
与自定义类不同,只有模块类型的前向方法和其他使用 @torch.jit.export
装饰的方法需要可编译。 最重要的是,__init__()
不被视为 TorchScript 方法。 因此,模块类型构造函数不能在 TorchScript 范围内调用。 相反,TorchScript 模块对象始终在外部构造,并传递给 torch.jit.script(ModuleObj)
。
示例 1
此示例说明了模块类型的一些功能。
TestModule
实例是在 TorchScript 范围之外创建的(即在调用torch.jit.script
之前)。__init__()
不被视为 TorchScript 方法,因此它不需要被注释,可以包含任意 Python 代码。 此外,实例类的__init__()
方法不能在 TorchScript 代码中调用。 因为TestModule
实例是在 Python 中实例化的,在这个例子中,TestModule(2.0)
和TestModule(2)
创建了两个实例,它们的 data 属性具有不同的类型。 对于TestModule(2.0)
,self.x
的类型是float
,而对于TestModule(2.0)
,self.y
的类型是int
。TorchScript 会自动编译其他方法(例如
mul()
),这些方法由使用@torch.jit.export
或forward()
方法注释的方法调用。TorchScript 程序的入口点是模块类型的前向方法
forward()
,被注释为torch.jit.script
的函数,或被注释为torch.jit.export
的方法。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, inc: int):
return self.x + inc
m = torch.jit.script(TestModule(1))
print(f"First instance: {m(3)}")
m = torch.jit.script(TestModule(torch.ones([5])))
print(f"Second instance: {m(3)}")
上面的示例产生以下输出
First instance: 4
Second instance: tensor([4., 4., 4., 4., 4.])
示例 2
以下示例展示了模块类型的错误用法。 具体来说,此示例在 TorchScript 范围内调用 TestModule
的构造函数。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, x: int):
return self.x + x
class MyModel:
def __init__(self, v: int):
self.val = v
@torch.jit.export
def doSomething(self, val: int) -> int:
# error: should not invoke the constructor of module type
myModel = TestModule(self.val)
return myModel(val)
# m = torch.jit.script(MyModel(2)) # Results in below RuntimeError
# RuntimeError: Could not get name of python class object
类型注释¶
由于 TorchScript 是静态类型的,因此程序员需要在 TorchScript 代码的策略点进行类型注释,以便每个局部变量或实例数据属性都具有静态类型,每个函数和方法都具有静态类型的签名。
何时进行类型注释¶
一般来说,类型注释只需要在无法自动推断出静态类型的地方(例如方法或函数的参数或有时是返回值)。 局部变量和数据属性的类型通常会从它们的赋值语句中自动推断。 有时推断出的类型可能过于严格,例如,x
通过赋值 x = None
被推断为 NoneType
,而 x
实际上用作 Optional
。 在这种情况下,可能需要类型注释来覆盖自动推断,例如,x: Optional[int] = None
。 请注意,即使局部变量或数据属性的类型可以自动推断,注释类型始终是安全的。 注释类型必须与 TorchScript 的类型检查一致。
当参数、局部变量或数据属性没有类型注释,并且其类型无法自动推断时,TorchScript 假设它是 TensorType
,List[TensorType]
或 Dict[str, TensorType]
的默认类型。
注释函数签名¶
由于参数可能无法从函数体(包括函数和方法)自动推断,因此需要对它们进行类型注释。 否则,它们将假设默认类型 TensorType
。
TorchScript 支持两种方法和函数签名类型注释的风格。
Python3 风格直接在签名上注释类型。 因此,它允许单个参数不加注释(其类型将是
TensorType
的默认类型),或允许返回值类型不加注释(其类型将被自动推断)。
Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
FuncOrMethodBody
ParamAnnot ::= Identifier [ ":" TSType ] ","
ReturnAnnot ::= "->" TSType
请注意,当使用 Python3 风格时,类型 self
会被自动推断,不应被注释。
Mypy 风格将类型注释为函数/方法声明正下方的注释。 在 Mypy 风格中,由于参数名称没有出现在注释中,因此所有参数都必须进行注释。
MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
ParamAnnot ::= TSType ","
ReturnAnnot ::= "->" TSType
示例 1
在此示例中。
a
未被注释,并假设默认类型为TensorType
。b
被注释为类型int
。返回值类型未被注释,并被自动推断为类型
TensorType
(基于返回的值的类型)。
import torch
def f(a, b: int):
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
示例 2
以下示例使用 Mypy 风格注释。 请注意,即使其中一些参数假设默认类型,也必须注释参数或返回值。
import torch
def f(a, b):
# type: (torch.Tensor, int) → torch.Tensor
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
注释变量和数据属性¶
一般来说,数据属性(包括类和实例数据属性)和局部变量的类型可以从赋值语句中自动推断。 但是,有时如果变量或属性与不同类型的值相关联(例如作为 None
或 TensorType
),那么它们可能需要被显式类型注释为更广泛的类型,例如 Optional[int]
或 Any
。
局部变量¶
局部变量可以根据 Python3 类型模块注释规则进行注释,即
LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr
一般来说,局部变量的类型可以自动推断。 但是,在某些情况下,您可能需要为可能与不同的具体类型相关联的局部变量注释一个多类型。 典型的多类型包括 Optional[T]
和 Any
。
示例
import torch
def f(a, setVal: bool):
value: Optional[torch.Tensor] = None
if setVal:
value = a
return value
ones = torch.ones([6])
m = torch.jit.script(f)
print("TorchScript:", m(ones, True), m(ones, False))
实例数据属性¶
对于 ModuleType
类,实例数据属性可以根据 Python3 类型模块注释规则进行注释。 实例数据属性可以使用 Final
(可选)进行注释为最终属性。
"class" ClassIdentifier "(torch.nn.Module):"
InstanceAttrIdentifier ":" ["Final("] TSType [")"]
...
在哪里
InstanceAttrIdentifier
是实例属性的名称。Final
表示该属性不能在__init__
之外重新分配,也不能在子类中被覆盖。
示例
import torch
class MyModule(torch.nn.Module):
offset_: int
def __init__(self, offset):
self.offset_ = offset
...
类型注释 API¶
torch.jit.annotate(T, expr)
¶
此 API 将类型 T
注释到表达式 expr
。这通常用于表达式的默认类型不是程序员想要使用的类型时。例如,空列表(字典)的默认类型为 List[TensorType]
(Dict[TensorType, TensorType]
),但有时它可能用于初始化其他类型列表。另一个常见用例是为 tensor.tolist()
的返回类型添加注释。但是,请注意,它不能用于注释 __init__ 中模块属性的类型;应改为使用 torch.jit.Attribute
。
示例
在此示例中,[]
通过 torch.jit.annotate
声明为整数列表(而不是假设 []
为 List[TensorType]
的默认类型)。
import torch
from typing import List
def g(l: List[int], val: int):
l.append(val)
return l
def f(val: int):
l = g(torch.jit.annotate(List[int], []), val)
return l
m = torch.jit.script(f)
print("Eager:", f(3))
print("TorchScript:", m(3))
有关更多信息,请参见 torch.jit.annotate()
。
类型注释附录¶
TorchScript 类型系统定义¶
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSMetaType ::= "Any"
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"|
"torch.dtype" | "torch.nn.ModuleList" |
"torch.nn.ModuleDict" | ...
TSTensor ::= "torch.tensor" and subclasses
不支持的类型构造¶
TorchScript 不支持 Python3 typing 模块的所有功能和类型。在本文档中未明确说明的 typing 模块的任何功能均不受支持。下表总结了在 TorchScript 中不受支持或受限制支持的 typing
构造。
项目 |
描述 |
|
开发中 |
|
不支持 |
|
不支持 |
|
不支持 |
|
不支持 |
|
支持模块属性、类属性和注释,但不支持函数。 |
|
不支持 |
|
开发中 |
类型别名 |
不支持 |
名义类型 |
开发中 |
结构化类型 |
不支持 |
NewType |
不支持 |
泛型 |
不支持 |
表达式¶
以下部分描述了 TorchScript 中支持的表达式的语法。它仿照了 Python 语言参考中的表达式章节。
算术转换¶
TorchScript 中执行了许多隐式类型转换。
具有
float
或int
数据类型的Tensor
可以隐式转换为FloatType
或IntType
的实例,前提是它的大小为 0,没有将require_grad
设置为True
,并且不需要缩小。StringType
的实例可以隐式转换为DeviceType
。上面两点中的隐式转换规则可以应用于
TupleType
的实例,以生成具有适当包含类型的ListType
实例。
显式转换可以使用 float
、int
、bool
和 str
内置函数来调用,这些函数接受原始数据类型作为参数,并且如果实现了 __bool__
、__str__
等,则可以接受用户定义的类型。
原子¶
原子是表达式的最基本元素。
atom ::= identifier | literal | enclosure
enclosure ::= parenth_form | list_display | dict_display
标识符¶
决定 TorchScript 中什么是合法标识符的规则与它们的 Python 对应项 相同。
文字¶
literal ::= stringliteral | integer | floatnumber
对文字的求值将生成具有特定值的适当类型的对象(对于浮点数,将应用必要的近似值)。文字是不可变的,对相同文字的多次求值可能会获得相同的对象或具有相同值的不同的对象。 stringliteral、 integer 和 floatnumber 的定义与其 Python 对应项相同。
带括号的形式¶
parenth_form ::= '(' [expression_list] ')'
带括号的表达式列表将生成表达式列表生成的任何内容。如果列表至少包含一个逗号,则它将生成一个 Tuple
;否则,它将生成表达式列表中的单个表达式。一对空的圆括号将生成一个空的 Tuple
对象 (Tuple[]
)。
列表和字典显示¶
list_comprehension ::= expression comp_for
comp_for ::= 'for' target_list 'in' or_expr
list_display ::= '[' [expression_list | list_comprehension] ']'
dict_display ::= '{' [key_datum_list | dict_comprehension] '}'
key_datum_list ::= key_datum (',' key_datum)*
key_datum ::= expression ':' expression
dict_comprehension ::= key_datum comp_for
列表和字典可以通过显式列出容器内容或通过提供关于如何通过一组循环指令(即 *理解*)计算它们的内容来构造。理解在语义上等效于使用 for 循环并追加到正在进行的列表。理解隐式创建自己的范围以确保目标列表的项目不会泄漏到封闭范围中。在显式列出容器项目的情况下,表达式列表中的表达式按从左到右的顺序求值。如果 dict_display
中的键重复,该 dict_display
具有 key_datum_list
,则生成的字典使用列表中使用重复键的最右侧数据的键值。
原语¶
primary ::= atom | attributeref | subscription | slicing | call
订阅¶
subscription ::= primary '[' expression_list ']'
primary
必须求值为支持订阅的对象。
如果 primary 为
List
、Tuple
或str
,则表达式列表必须求值为整数或切片。如果 primary 为
Dict
,则表达式列表必须求值为与Dict
的键类型相同的类型的对象。如果 primary 为
ModuleList
,则表达式列表必须为integer
文字。如果 primary 为
ModuleDict
,则表达式必须为stringliteral
。
切片¶
切片选择 str
、Tuple
、List
或 Tensor
中的一系列项目。切片可以用作赋值语句或 del
语句中的表达式或目标。
slicing ::= primary '[' slice_list ']'
slice_list ::= slice_item (',' slice_item)* [',']
slice_item ::= expression | proper_slice
proper_slice ::= [expression] ':' [expression] [':' [expression] ]
在其切片列表中包含多个切片项目的切片只能与求值为 Tensor
类型对象的 primary 一起使用。
调用¶
call ::= primary '(' argument_list ')'
argument_list ::= args [',' kwargs] | kwargs
args ::= [arg (',' arg)*]
kwargs ::= [kwarg (',' kwarg)*]
kwarg ::= arg '=' expression
arg ::= identifier
primary
必须反糖或求值为可调用对象。所有参数表达式都在尝试调用之前求值。
幂运算符¶
power ::= primary ['**' u_expr]
幂运算符的语义与内置的 pow 函数相同(不支持);它计算其左操作数的右操作数次幂。它比左边的单目运算符绑定得更紧密,但比右边的单目运算符绑定得更松散;即 -2 ** -3 == -(2 ** (-3))
。左右操作数可以是 int
、float
或 Tensor
。标量在标量-张量/张量-标量指数运算操作的情况下被广播,张量-张量指数运算按元素进行,没有任何广播。
一元和算术位运算¶
u_expr ::= power | '-' power | '~' power
一元 -
运算符生成其参数的否定。一元 ~
运算符生成其参数的按位求反。-
可以与 int
、float
和 Tensor
的 int
和 float
一起使用。~
只能与 int
和 Tensor
的 int
一起使用。
二元算术运算¶
m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr
a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr
二元算术运算符可以对 Tensor
、int
和 float
进行操作。对于张量-张量操作,两个参数必须具有相同的形状。对于标量-张量或张量-标量操作,标量通常会广播到张量的尺寸。除法运算只能接受标量作为其右侧参数,并且不支持广播。 @
运算符用于矩阵乘法,并且只对 Tensor
参数进行操作。乘法运算符 (*
) 可以与列表和整数一起使用,以获得将原始列表重复一定次数的结果。
移位操作¶
shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr
这些运算符接受两个 int
参数,两个 Tensor
参数,或一个 Tensor
参数和一个 int
或 float
参数。在所有情况下,右移 n
被定义为对 pow(2, n)
进行地板除,而左移 n
被定义为乘以 pow(2, n)
。当两个参数都是 Tensors
时,它们必须具有相同的形状。当一个参数是标量,另一个参数是 Tensor
时,标量会逻辑地广播以匹配 Tensor
的大小。
二元按位操作¶
and_expr ::= shift_expr | and_expr '&' shift_expr
xor_expr ::= and_expr | xor_expr '^' and_expr
or_expr ::= xor_expr | or_expr '|' xor_expr
&
运算符计算其参数的按位与,^
计算按位异或,而 |
计算按位或。两个操作数必须是 int
或 Tensor
,或者左操作数必须是 Tensor
,而右操作数必须是 int
。当两个操作数都是 Tensor
时,它们必须具有相同的形状。当右操作数为 int
,而左操作数为 Tensor
时,右操作数会逻辑地广播以匹配 Tensor
的形状。
比较¶
comparison ::= or_expr (comp_operator or_expr)*
comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'
比较会产生一个布尔值 (True
或 False
),或者如果其中一个操作数是 Tensor
,则会产生一个布尔 Tensor
。比较可以任意地链接,只要它们不会产生具有多个元素的布尔 Tensors
。 a op1 b op2 c ...
等价于 a op1 b and b op2 c and ...
。
值比较¶
运算符 <
、>
、==
、>=
、<=
和 !=
比较两个对象的数值。这两个对象通常需要是相同类型,除非在对象之间存在隐式类型转换。如果在用户定义的类型上定义了丰富的比较方法(例如,__lt__
),则可以比较它们。内置类型比较的行为类似于 Python。
数字以数学方式进行比较。
字符串以字典序进行比较。
lists
、tuples
和dicts
只能与相同类型且使用相应元素的比较运算符进行比较。
成员测试操作¶
运算符 in
和 not in
用于测试成员资格。 x in s
如果 x
是 s
的成员,则评估为 True
,否则评估为 False
。 x not in s
等价于 not x in s
。此运算符支持 lists
、dicts
和 tuples
,并且可以在用户定义的类型上使用,前提是它们实现了 __contains__
方法。
身份比较¶
对于除 int
、double
、bool
和 torch.device
之外的所有类型,运算符 is
和 is not
用于测试对象的标识;x is y
当且仅当 x
和 y
是同一个对象时为 True
。对于所有其他类型,is
等效于使用 ==
对它们进行比较。 x is not y
会产生 x is y
的逆。
布尔运算¶
or_test ::= and_test | or_test 'or' and_test
and_test ::= not_test | and_test 'and' not_test
not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test
用户定义的对象可以通过实现 __bool__
方法来自定义它们转换为 bool
的方式。运算符 not
如果其操作数为假,则产生 True
,否则产生 False
。表达式 x
and y
首先评估 x
;如果它为 False
,则返回其值 (False
);否则,评估 y
并返回其值 (False
或 True
)。表达式 x
or y
首先评估 x
;如果它为 True
,则返回其值 (True
);否则,评估 y
并返回其值 (False
或 True
)。
条件表达式¶
conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression]
expression ::= conditional_expression
表达式 x if c else y
首先评估条件 c
而不是 x。如果 c
为 True
,则评估 x
并返回其值;否则,评估 y
并返回其值。与 if 语句一样,x
和 y
必须评估为相同类型的值。
表达式列表¶
expression_list ::= expression (',' expression)* [',']
starred_item ::= '*' primary
星号项只能出现在赋值语句的左侧,例如,a, *b, c = ...
。
简单语句¶
以下部分描述了 TorchScript 中支持的简单语句的语法。它是根据 Python 语言参考中的简单语句章节 建模的。
表达式语句¶
expression_stmt ::= starred_expression
starred_expression ::= expression | (starred_item ",")* [starred_item]
starred_item ::= assignment_expression | "*" or_expr
赋值语句¶
assignment_stmt ::= (target_list "=")+ (starred_expression)
target_list ::= target ("," target)* [","]
target ::= identifier
| "(" [target_list] ")"
| "[" [target_list] "]"
| attributeref
| subscription
| slicing
| "*" target
增强赋值语句¶
augmented_assignment_stmt ::= augtarget augop (expression_list)
augtarget ::= identifier | attributeref | subscription
augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" |
"**="| ">>=" | "<<=" | "&=" | "^=" | "|="
带注释的赋值语句¶
annotated_assignment_stmt ::= augtarget ":" expression
["=" (starred_expression)]
raise
语句¶
raise_stmt ::= "raise" [expression ["from" expression]]
TorchScript 中的 raise 语句不支持 try\except\finally
。
assert
语句¶
assert_stmt ::= "assert" expression ["," expression]
TorchScript 中的 assert 语句不支持 try\except\finally
。
return
语句¶
return_stmt ::= "return" [expression_list]
TorchScript 中的 return 语句不支持 try\except\finally
。
del
语句¶
del_stmt ::= "del" target_list
pass
语句¶
pass_stmt ::= "pass"
print
语句¶
print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")"
break
语句¶
break_stmt ::= "break"
continue
语句:¶
continue_stmt ::= "continue"
复合语句¶
以下部分描述了 TorchScript 支持的复合语句语法。该部分还重点介绍了 Torchscript 与常规 Python 语句的不同之处。它是根据 Python 语言参考中的复合语句章节 建模的。
if
语句¶
Torchscript 支持基本 if/else
和三元 if/else
。
基本 if/else
语句¶
if_stmt ::= "if" assignment_expression ":" suite
("elif" assignment_expression ":" suite)
["else" ":" suite]
elif
语句可以重复任意次数,但它需要位于 else
语句之前。
三元 if/else
语句¶
if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list]
示例 1
一个维度为 1 的 tensor
会被提升为 bool
import torch
@torch.jit.script
def fn(x: torch.Tensor):
if x: # The tensor gets promoted to bool
return True
return False
print(fn(torch.rand(1)))
上面的示例产生以下输出
True
示例 2
一个维度为多个的 tensor
不会被提升为 bool
import torch
# Multi dimensional Tensors error out.
@torch.jit.script
def fn():
if torch.rand(2):
print("Tensor is available")
if torch.rand(4,5,6):
print("Tensor is available")
print(fn())
运行上面的代码会产生以下 RuntimeError
。
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
@torch.jit.script
def fn():
if torch.rand(2):
~~~~~~~~~~~~ <--- HERE
print("Tensor is available")
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
如果条件变量被注解为 final
,则根据条件变量的评估结果,真或假分支中的一个会被评估。
示例 3
在这个例子中,只有真分支被评估,因为 a
被注解为 final
并被设置为 True
import torch
a : torch.jit.final[Bool] = True
if a:
return torch.empty(2,3)
else:
return []
while
语句¶
while_stmt ::= "while" assignment_expression ":" suite
while…else 语句在 Torchscript 中不受支持。它会导致 RuntimeError
。
for-in
语句¶
for_stmt ::= "for" target_list "in" expression_list ":" suite
["else" ":" suite]
for...else
语句在 Torchscript 中不受支持。它会导致 RuntimeError
。
示例 1
在元组上的 for 循环:这些循环会展开循环,为元组中的每个成员生成一个主体。主体必须对每个成员进行正确的类型检查。
import torch
from typing import Tuple
@torch.jit.script
def fn():
tup = (3, torch.ones(4))
for x in tup:
print(x)
fn()
上面的示例产生以下输出
3
1
1
1
1
[ CPUFloatType{4} ]
示例 2
在列表上的 for 循环:在 nn.ModuleList
上的 for 循环将在编译时展开循环体,并包含模块列表中的每个成员。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
model = torch.jit.script(MyModule())
with
语句¶
with
语句用于用上下文管理器定义的方法包装代码块的执行。
with_stmt ::= "with" with_item ("," with_item) ":" suite
with_item ::= expression ["as" target]
如果
with
语句中包含目标,则上下文管理器的__enter__()
的返回值会被分配给它。与 python 不同的是,如果异常导致代码块退出,它的类型、值和回溯不会作为参数传递给__exit__()
。会提供三个None
参数。try
、except
和finally
语句在with
代码块中不受支持。在
with
代码块中引发的异常无法被抑制。
tuple
语句¶
tuple_stmt ::= tuple([iterables])
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。你不能使用这个内置函数将列表转换为元组。
将所有输出解包到一个元组中,由以下内容涵盖
abc = func() # Function that returns a tuple
a,b = func()
getattr
语句¶
getattr_stmt ::= getattr(object, name[, default])
属性名称必须是文字字符串。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
hasattr
语句¶
hasattr_stmt ::= hasattr(object, name)
属性名称必须是文字字符串。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
zip
语句¶
zip_stmt ::= zip(iterable1, iterable2)
参数必须是可迭代的。
支持两个具有相同外部容器类型但长度不同的可迭代对象。
示例 1
这两个可迭代对象必须具有相同的容器类型
a = [1, 2] # List
b = [2, 3, 4] # List
zip(a, b) # works
示例 2
此示例失败,因为可迭代对象具有不同的容器类型
a = (1, 2) # Tuple
b = [2, 3, 4] # List
zip(a, b) # Runtime error
运行上面的代码会产生以下 RuntimeError
。
RuntimeError: Can not iterate over a module list or
tuple with a value that does not have a statically determinable length.
示例 3
支持具有相同容器类型但数据类型不同的两个可迭代对象
a = [1.3, 2.4]
b = [2, 3, 4]
zip(a, b) # Works
TorchScript 中的可迭代类型包括 Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
enumerate
语句¶
enumerate_stmt ::= enumerate([iterable])
参数必须是可迭代的。
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。
Python 值¶
解析规则¶
当给定一个 Python 值时,TorchScript 会尝试通过以下五种不同的方式解析它
- 可编译的 Python 实现
当 Python 值由 TorchScript 可以编译的 Python 实现支持时,TorchScript 会编译并使用底层的 Python 实现。
示例:
torch.jit.Attribute
- 运算符 Python 包装器
当 Python 值是原生 PyTorch 运算符的包装器时,TorchScript 会发出相应的运算符。
示例:
torch.jit._logging.add_stat_value
- Python 对象身份匹配
对于 TorchScript 支持的一组有限的
torch.*
API 调用(以 Python 值的形式),TorchScript 会尝试将 Python 值与集合中的每个项目进行匹配。当匹配成功时,TorchScript 会生成一个相应的
SugaredValue
实例,该实例包含这些值的降低逻辑。示例:
torch.jit.isinstance()
- 名称匹配
对于 Python 内置函数和常量,TorchScript 会通过名称识别它们,并创建一个相应的
SugaredValue
实例来实现它们的功能。示例:
all()
- 值快照
对于来自无法识别模块的 Python 值,TorchScript 会尝试对该值进行快照,并将其转换为正在编译的函数或方法图中的一个常量。
示例:
math.pi
Python 内置函数支持¶
内置函数 |
支持级别 |
备注 |
---|---|---|
|
部分 |
只支持 |
|
完整 |
|
|
完整 |
|
|
无 |
|
|
部分 |
只支持 |
|
部分 |
只支持 |
|
无 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
部分 |
只支持 ASCII 字符集。 |
|
完整 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
完整 |
|
|
无 |
|
|
完整 |
|
|
完整 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
部分 |
不尊重 |
|
部分 |
不支持手动索引指定。 | 不支持格式类型修饰符。 |
|
无 |
|
|
部分 |
属性名称必须是字符串文字。 |
|
无 |
|
|
部分 |
属性名称必须是字符串文字。 |
|
完整 |
|
|
部分 |
只支持 |
|
完整 |
只支持 |
|
无 |
|
|
部分 |
不支持 |
|
完整 |
|
|
无 |
|
|
无 |
|
|
完整 |
|
|
完整 |
|
|
部分 |
只支持 ASCII 字符集。 |
|
完整 |
|
|
部分 |
不支持 |
|
无 |
|
|
完整 |
|
|
无 |
|
|
无 |
|
|
部分 |
不支持 |
|
无 |
|
|
无 |
|
|
完整 |
|
|
部分 |
不支持 |
|
完整 |
|
|
部分 |
不支持 |
|
完整 |
|
|
部分 |
它只能在 |
|
无 |
|
|
无 |
|
|
完整 |
|
|
无 |
torch.* API¶
远程过程调用¶
TorchScript 支持一部分 RPC API,这些 API 支持在指定的远程工作节点上而不是本地运行函数。
具体来说,以下 API 完全受支持
torch.distributed.rpc.rpc_sync()
rpc_sync()
对远程工作节点进行阻塞式 RPC 调用以运行函数。RPC 消息在发送和接收时与 Python 代码执行并行进行。有关其用法和示例的更多详细信息,请参见
rpc_sync()
.
torch.distributed.rpc.rpc_async()
rpc_async()
对远程工作节点进行非阻塞式 RPC 调用以运行函数。RPC 消息在发送和接收时与 Python 代码执行并行进行。有关其用法和示例的更多详细信息,请参见
rpc_async()
.
torch.distributed.rpc.remote()
remote.()
在工作节点上执行远程调用,并获得远程引用RRef
作为返回值。有关其用法和示例的更多详细信息,请参见
remote()
.
异步执行¶
TorchScript 使您能够创建异步计算任务,以更好地利用计算资源。这可以通过支持一些仅在 TorchScript 中可用的 API 来实现。
类型注解¶
TorchScript 是静态类型的。它提供并支持一组工具来帮助您注释变量和属性。
torch.jit.annotate()
在 Python 3 样式的类型提示无法正常工作的情况下,为 TorchScript 提供类型提示。
一个常见的例子是为
[]
等表达式注释类型。默认情况下,[]
被视为List[torch.Tensor]
。当需要其他类型时,您可以使用此代码提示 TorchScript:torch.jit.annotate(List[int], [])
。更多详细信息请参见
annotate()
torch.jit.Attribute
常见用例包括为
torch.nn.Module
属性提供类型提示。因为它们的__init__
方法不会被 TorchScript 解析,所以应该在模块的__init__
方法中使用torch.jit.Attribute
而不是torch.jit.annotate
。更多详细信息请参见
Attribute()
torch.jit.Final
Python 的
typing.Final
的别名。torch.jit.Final
仅出于向后兼容性原因保留。
元编程¶
TorchScript 提供了一组工具来促进元编程。
torch.jit.is_scripting()
返回一个布尔值,指示当前程序是通过
torch.jit.script
编译的还是通过其他方式编译的。当用在
assert
或if
语句中时,torch.jit.is_scripting()
评估为False
的范围或分支将不会被编译。它的值可以在编译时静态评估,因此通常用于
if
语句中,以阻止 TorchScript 编译其中一个分支。更多详细信息和示例请参见
is_scripting()
torch.jit.is_tracing()
返回一个布尔值,指示当前程序是通过
torch.jit.trace
/torch.jit.trace_module
跟踪的还是通过其他方式跟踪的。更多详细信息请参见
is_tracing()
@torch.jit.ignore
此装饰器指示编译器忽略函数或方法,将其保留为 Python 函数。
这允许您在模型中保留尚不支持 TorchScript 的代码。
如果从 TorchScript 调用了用
@torch.jit.ignore
装饰的函数,被忽略的函数将把调用分派给 Python 解释器。包含被忽略函数的模型无法导出。
更多详细信息和示例请参见
ignore()
@torch.jit.unused
此装饰器指示编译器忽略函数或方法,并将其替换为抛出异常。
这允许您在模型中保留尚不支持 TorchScript 的代码,同时仍然可以导出您的模型。
如果从 TorchScript 调用了用
@torch.jit.unused
装饰的函数,将抛出运行时错误。更多详细信息和示例请参见
unused()
类型细化¶
torch.jit.isinstance()
返回一个布尔值,指示变量是否为指定类型。
有关其用法和示例的更多详细信息,请参见
isinstance()
.