TorchScript 语言参考¶
本参考手册描述了 TorchScript 语言的语法和核心语义。TorchScript 是 Python 语言的一个静态类型子集。本文档解释了 TorchScript 中支持的 Python 特性,以及该语言与普通 Python 的区别。本参考手册中未提及的任何 Python 特性均不属于 TorchScript。TorchScript 特别关注在 PyTorch 中表示神经网络模型所需的 Python 特性。
术语¶
本文档使用以下术语
模式 |
注释 |
---|---|
|
表示给定符号的定义方式。 |
|
表示语法中的实际关键字和分隔符。 |
|
表示 A 或 B。 |
|
表示分组。 |
|
表示可选。 |
|
表示正则表达式,其中项 A 重复至少一次。 |
|
表示正则表达式,其中项 A 重复零次或多次。 |
类型系统¶
TorchScript 是 Python 的一个静态类型子集。TorchScript 与完整 Python 语言的最大区别在于,TorchScript 只支持表达神经网络模型所需的一小部分类型。
TorchScript 类型¶
TorchScript 类型系统包含 TSType
和 TSModuleType
,定义如下。
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSType
表示大多数 TorchScript 类型,它们是可组合的,并且可以在 TorchScript 类型标注中使用。TSType
指代以下任何一种类型
元类型,例如
Any
基本类型,例如
int
,float
和str
结构类型,例如
Optional[int]
或List[MyClass]
名义类型(Python 类),例如
MyClass
(用户定义),torch.tensor
(内置)
TSModuleType
表示 torch.nn.Module
及其子类。它与 TSType
的处理方式不同,因为其类型模式部分从对象实例推断,部分从类定义推断。因此,TSModuleType
的实例可能不遵循相同的静态类型模式。出于类型安全考虑,TSModuleType
不能用作 TorchScript 类型标注,也不能与 TSType
组合使用。
元类型¶
元类型非常抽象,它们更像是类型约束而非具体类型。目前 TorchScript 定义了一种元类型 Any
,它表示任何 TorchScript 类型。
Any
类型¶
Any
类型表示任何 TorchScript 类型。Any
不指定任何类型约束,因此对 Any
不进行类型检查。因此,它可以绑定到任何 Python 或 TorchScript 数据类型(例如 int
、TorchScript tuple
或任何未被 script 的 Python 类)。
TSMetaType ::= "Any"
其中
Any
是 typing 模块中的 Python 类名。因此,要使用Any
类型,必须从typing
导入它(例如from typing import Any
)。由于
Any
可以表示任何 TorchScript 类型,因此允许对该类型的值执行的运算符集是有限的。
Any
类型支持的运算符¶
对
Any
类型数据的赋值。绑定到
Any
类型的参数或返回值。x is
,x is not
,其中x
是Any
类型。isinstance(x, Type)
,其中x
是Any
类型。Any
类型的数据是可打印的。如果数据是相同类型
T
的值列表且T
支持比较运算符,则List[Any]
类型的数据可能是可排序的。
与 Python 相比
Any
是 TorchScript 类型系统中约束最少的类型。从这个意义上说,它与 Python 中的 Object
类非常相似。但是,Any
只支持 Object
支持的运算符和方法的一个子集。
设计说明¶
当我们 script 一个 PyTorch 模块时,可能会遇到不参与脚本执行的数据。尽管如此,它仍然必须由类型模式描述。描述未使用数据(在脚本上下文中)的静态类型不仅麻烦,还可能导致不必要的 scripting 失败。Any
的引入是为了描述那些对于编译而言不需要精确静态类型的数据的类型。
示例 1
此示例说明了如何使用 Any
允许元组参数的第二个元素可以是任何类型。这是可能的,因为 x[1]
不涉及任何需要知道其精确类型的计算。
import torch
from typing import Tuple
from typing import Any
@torch.jit.export
def inc_first_element(x: Tuple[int, Any]):
return (x[0]+1, x[1])
m = torch.jit.script(inc_first_element)
print(m((1,2.0)))
print(m((1,(100,200))))
上面的示例产生以下输出
(2, 2.0)
(2, (100, 200))
元组的第二个元素是 Any
类型,因此可以绑定到多种类型。例如,(1, 2.0)
将 float 类型绑定到 Any
,就像在 Tuple[int, Any]
中一样;而 (1, (100, 200))
在第二次调用中将 tuple 绑定到 Any
。
示例 2
此示例说明了如何使用 isinstance
动态检查标注为 Any
类型的数据的类型
import torch
from typing import Any
def f(a:Any):
print(a)
return (isinstance(a, torch.Tensor))
ones = torch.ones([2])
m = torch.jit.script(f)
print(m(ones))
上面的示例产生以下输出
1
1
[ CPUFloatType{2} ]
True
基本类型¶
TorchScript 基本类型是表示单一类型值并具有单一预定义类型名称的类型。
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
结构类型¶
结构类型是结构化定义且没有用户定义名称(与名义类型不同)的类型,例如 Future[int]
。结构类型可以与任何 TSType
组合使用。
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict |
TSOptional | TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
其中
Tuple
,List
,Optional
,Union
,Future
,Dict
表示在typing
模块中定义的 Python 类型类名。要使用这些类型名,必须从typing
导入它们(例如from typing import Tuple
)。namedtuple
表示 Python 类collections.namedtuple
或typing.NamedTuple
。Future
和RRef
表示 Python 类torch.futures
和torch.distributed.rpc
。Await
表示 Python 类torch._awaits._Await
与 Python 相比
除了可以与 TorchScript 类型组合外,这些 TorchScript 结构类型通常支持其 Python 对应类型的一组通用运算符和方法。
示例 1
此示例使用 typing.NamedTuple
语法定义元组
import torch
from typing import NamedTuple
from typing import Tuple
class MyTuple(NamedTuple):
first: int
second: int
def inc(x: MyTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
t = MyTuple(first=1, second=2)
scripted_inc = torch.jit.script(inc)
print("TorchScript:", scripted_inc(t))
上面的示例产生以下输出
TorchScript: (2, 3)
示例 2
此示例使用 collections.namedtuple
语法定义元组
import torch
from typing import NamedTuple
from typing import Tuple
from collections import namedtuple
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second'])
def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
m = torch.jit.script(inc)
print(inc(_UnannotatedNamedTuple(1,2)))
上面的示例产生以下输出
(2, 3)
示例 3
此示例说明了标注结构类型时的一个常见错误,即没有从 typing
模块导入复合类型类
import torch
# ERROR: Tuple not recognized because not imported from typing
@torch.jit.export
def inc(x: Tuple[int, int]):
return (x[0]+1, x[1]+1)
m = torch.jit.script(inc)
print(m((1,2)))
运行上述代码会产生以下 scripting 错误
File "test-tuple.py", line 5, in <module>
def inc(x: Tuple[int, int]):
NameError: name 'Tuple' is not defined
解决方法是在代码开头添加 from typing import Tuple
这一行。
名义类型¶
TorchScript 名义类型是 Python 类。这些类型之所以称为名义类型,是因为它们是使用自定义名称声明的,并使用类名进行比较。名义类进一步分为以下几类
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
其中,TSCustomClass
和 TSEnum
必须可以编译为 TorchScript 中间表示 (IR)。这是由类型检查器强制执行的。
内置类¶
内置名义类型是其语义已内置到 TorchScript 系统中的 Python 类(例如 tensor 类型)。TorchScript 定义了这些内置名义类型的语义,并且通常只支持其 Python 类定义的方法或属性的一个子集。
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
"torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
"torch.nn.parameter.Parameter" | and subclasses of torch.Tensor
关于 torch.nn.ModuleList 和 torch.nn.ModuleDict 的特别说明¶
尽管 torch.nn.ModuleList
和 torch.nn.ModuleDict
在 Python 中定义为列表和字典,但在 TorchScript 中它们更像元组
在 TorchScript 中,
torch.nn.ModuleList
或torch.nn.ModuleDict
的实例是不可变的。迭代
torch.nn.ModuleList
或torch.nn.ModuleDict
的代码会完全展开,以便torch.nn.ModuleList
的元素或torch.nn.ModuleDict
的键可以是torch.nn.Module
的不同子类。
示例
以下示例重点介绍了几个内置 Torchscript 类(torch.*
)的使用
import torch
@torch.jit.script
class A:
def __init__(self):
self.x = torch.rand(3)
def f(self, y: torch.device):
return self.x.to(device=y)
def g():
a = A()
return a.f(torch.device("cpu"))
script_g = torch.jit.script(g)
print(script_g.graph)
自定义类¶
与内置类不同,自定义类的语义是用户定义的,并且整个类定义必须可编译为 TorchScript IR 并遵守 TorchScript 类型检查规则。
TSClassDef ::= [ "@torch.jit.script" ]
"class" ClassName [ "(object)" ] ":"
MethodDefinition |
[ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
MethodDefinition
其中
类必须是新式类。Python 3 只支持新式类。在 Python 2.x 中,通过继承 object 来指定新式类。
实例数据属性是静态类型的,并且实例属性必须在
__init__()
方法内部通过赋值声明。不支持方法重载(即,不能有多个同名方法)。
MethodDefinition
必须可编译为 TorchScript IR 并遵守 TorchScript 的类型检查规则(即,所有方法都必须是有效的 TorchScript 函数,并且类属性定义必须是有效的 TorchScript 语句)。torch.jit.ignore
和torch.jit.unused
可用于忽略未完全 torchscript 化或应被编译器忽略的方法或函数。
与 Python 相比
与 Python 自定义类相比,TorchScript 自定义类有相当多的限制。TorchScript 自定义类
不支持类属性。
除了继承接口类型或 object 外,不支持子类化。
不支持方法重载。
必须在
__init__()
中初始化其所有实例属性;这是因为 TorchScript 通过在__init__()
中推断属性类型来构建类的静态模式。必须只包含满足 TorchScript 类型检查规则且可编译为 TorchScript IRs 的方法。
示例 1
如果 Python 类用 @torch.jit.script
标注,就可以在 TorchScript 中使用,类似于 TorchScript 函数的声明方式
@torch.jit.script
class MyClass:
def __init__(self, x: int):
self.x = x
def inc(self, val: int):
self.x += val
示例 2
TorchScript 自定义类类型必须通过在 __init__()
中赋值来“声明”其所有实例属性。如果一个实例属性未在 __init__()
中定义,但在类的其他方法中访问,则该类无法编译为 TorchScript 类,如下例所示
import torch
@torch.jit.script
class foo:
def __init__(self):
self.y = 1
# ERROR: self.x is not defined in __init__
def assign_x(self):
self.x = torch.rand(2, 3)
该类将编译失败并发出以下错误
RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
self.x = torch.rand(2, 3)
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
示例 3
在此示例中,TorchScript 自定义类定义了一个类变量名,这是不允许的
import torch
@torch.jit.script
class MyClass(object):
name = "MyClass"
def __init__(self, x: int):
self.x = x
def fn(a: MyClass):
return a.name
这会导致以下编译时错误
RuntimeError:
'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
File "test-class2.py", line 10
def fn(a: MyClass):
return a.name
~~~~~~ <--- HERE
枚举类型¶
与自定义类类似,枚举类型的语义是用户定义的,并且整个类定义必须可编译为 TorchScript IR 并遵守 TorchScript 类型检查规则。
TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":"
( MemberIdentifier "=" Value )+
( MethodDefinition )*
其中
值必须是类型为
int
,float
或str
的 TorchScript 字面量,并且必须是相同的 TorchScript 类型。TSEnumType
是 TorchScript 枚举类型的名称。与 Python enum 类似,TorchScript 允许受限的Enum
子类化,即只有当一个枚举类没有定义任何成员时,才允许对其进行子类化。
与 Python 相比
TorchScript 只支持
enum.Enum
。它不支持其他变体,例如enum.IntEnum
,enum.Flag
,enum.IntFlag
和enum.auto
。TorchScript 枚举成员的值必须是相同类型,并且只能是
int
,float
或str
类型,而 Python 枚举成员可以是任何类型。在 TorchScript 中,包含方法的枚举会被忽略。
示例 1
以下示例将类 Color
定义为 Enum
类型
import torch
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
print("TorchScript: ", m(Color.RED, Color.GREEN))
示例 2
以下示例展示了受限枚举子类化的情形,其中 BaseColor
未定义任何成员,因此可以被 Color
子类化
import torch
from enum import Enum
class BaseColor(Enum):
def foo(self):
pass
class Color(BaseColor):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("TorchScript: ", m(Color.RED, Color.GREEN))
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
TorchScript 模块类¶
TSModuleType
是一种特殊的类类型,它从 TorchScript 外部创建的对象实例推断而来。TSModuleType
的名称即为对象实例的 Python 类名。Python 类的 __init__()
方法不被视为 TorchScript 方法,因此无需遵守 TorchScript 的类型检查规则。
模块实例类的类型模式直接从实例对象(在 TorchScript 范围之外创建)构建,而不是像自定义类那样从 __init__()
推断。相同实例类类型的两个对象可能遵循两种不同的类型模式。
从这个意义上说,TSModuleType
并不是真正的静态类型。因此,出于类型安全考虑,TSModuleType
不能用于 TorchScript 类型标注,也不能与 TSType
组合使用。
模块实例类¶
TorchScript 模块类型表示用户定义的 PyTorch 模块实例的类型模式。当 script 一个 PyTorch 模块时,模块对象总是在 TorchScript 外部创建(即作为参数传递给 forward
)。Python 模块类被视为模块实例类,因此 Python 模块类的 __init__()
方法不受 TorchScript 类型检查规则的约束。
TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":"
ClassBodyDefinition
其中
forward()
以及用@torch.jit.export
装饰的其他方法必须可编译为 TorchScript IR 并遵守 TorchScript 的类型检查规则。
与自定义类不同,只需模块类型的 forward 方法以及用 @torch.jit.export
或 forward()
装饰的其他方法可编译即可。最值得注意的是,__init__()
不被视为 TorchScript 方法。因此,模块类型构造函数不能在 TorchScript 范围内调用。相反,TorchScript 模块对象总是在外部构建并传递给 torch.jit.script(ModuleObj)
。
示例 1
此示例说明了模块类型的一些特性
TestModule
实例是在 TorchScript 范围之外创建的(即在调用torch.jit.script
之前)。__init__()
不被视为 TorchScript 方法,因此无需标注,并且可以包含任意 Python 代码。此外,实例类的__init__()
方法不能在 TorchScript 代码中调用。由于TestModule
实例是在 Python 中实例化的,在此示例中,TestModule(2.0)
和TestModule(2)
创建了两个实例,其数据属性具有不同的类型。TestModule(2.0)
的self.x
类型是float
,而TestModule(2)
的self.y
类型是int
。TorchScript 会自动编译通过
@torch.jit.export
标注的方法或forward()
方法调用的其他方法(例如mul()
)。TorchScript 程序的入口点可以是模块类型的
forward()
、标注为torch.jit.script
的函数,或标注为torch.jit.export
的方法。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, inc: int):
return self.x + inc
m = torch.jit.script(TestModule(1))
print(f"First instance: {m(3)}")
m = torch.jit.script(TestModule(torch.ones([5])))
print(f"Second instance: {m(3)}")
上面的示例产生以下输出
First instance: 4
Second instance: tensor([4., 4., 4., 4., 4.])
示例 2
以下示例展示了模块类型的不正确用法。具体来说,此示例在 TorchScript 范围内调用了 TestModule
的构造函数
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, x: int):
return self.x + x
class MyModel:
def __init__(self, v: int):
self.val = v
@torch.jit.export
def doSomething(self, val: int) -> int:
# error: should not invoke the constructor of module type
myModel = TestModule(self.val)
return myModel(val)
# m = torch.jit.script(MyModel(2)) # Results in below RuntimeError
# RuntimeError: Could not get name of python class object
类型标注¶
由于 TorchScript 是静态类型的,程序员需要在 TorchScript 代码的*关键点*标注类型,以便每个局部变量或实例数据属性都具有静态类型,并且每个函数和方法都具有静态类型签名。
何时标注类型¶
通常,类型标注只在无法自动推断静态类型的地方需要(例如参数或有时是方法或函数的返回类型)。局部变量和数据属性的类型通常可以从它们的赋值语句中自动推断。有时推断的类型可能过于严格,例如,通过赋值 x = None
将 x
推断为 NoneType
,而实际上 x
被用作 Optional
。在这种情况下,可能需要类型标注来覆盖自动推断,例如 x: Optional[int] = None
。请注意,即使可以自动推断类型,为局部变量或数据属性标注类型也是安全的。标注的类型必须与 TorchScript 的类型检查一致。
当参数、局部变量或数据属性未标注类型且无法自动推断其类型时,TorchScript 假定其默认类型为 TensorType
, List[TensorType]
或 Dict[str, TensorType]
。
标注函数签名¶
由于参数可能无法从函数体(包括函数和方法)自动推断,因此需要进行类型标注。否则,它们将假定默认类型为 TensorType
。
TorchScript 支持两种方法和函数签名类型注解的风格
Python3 风格直接在签名上注解类型。因此,它允许某些参数不进行注解(其类型将是
TensorType
的默认类型),或者允许不注解返回类型(其类型将自动推断)。
Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
FuncOrMethodBody
ParamAnnot ::= Identifier [ ":" TSType ] ","
ReturnAnnot ::= "->" TSType
请注意,在使用 Python3 风格时,类型 self
会自动推断,不应进行注解。
Mypy 风格将类型注解为紧跟在函数/方法声明下方的一条注释。在 Mypy 风格中,由于参数名不出现在注解中,因此所有参数都必须进行注解。
MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
ParamAnnot ::= TSType ","
ReturnAnnot ::= "->" TSType
示例 1
在此示例中
a
未注解,并假定其默认类型为TensorType
。b
注解为int
类型。返回类型未注解,并自动推断为
TensorType
类型(基于返回值的类型)。
import torch
def f(a, b: int):
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
示例 2
以下示例使用 Mypy 风格注解。请注意,即使参数或返回值假定为默认类型,也必须进行注解。
import torch
def f(a, b):
# type: (torch.Tensor, int) → torch.Tensor
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
注解变量和数据属性¶
通常,数据属性(包括类数据属性和实例数据属性)和局部变量的类型可以从赋值语句中自动推断。但是,有时如果变量或属性与不同类型的值相关联(例如,None
或 TensorType
),则可能需要将其显式注解为更“宽泛”的类型,例如 Optional[int]
或 Any
。
局部变量¶
局部变量可以根据 Python3 typing 模块的注解规则进行注解,即,
LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr
通常,局部变量的类型可以自动推断。然而在某些情况下,对于可能与不同具体类型相关联的局部变量,您可能需要注解一个多类型。典型的多类型包括 Optional[T]
和 Any
。
示例
import torch
def f(a, setVal: bool):
value: Optional[torch.Tensor] = None
if setVal:
value = a
return value
ones = torch.ones([6])
m = torch.jit.script(f)
print("TorchScript:", m(ones, True), m(ones, False))
实例数据属性¶
对于 ModuleType
类,实例数据属性可以根据 Python3 typing 模块的注解规则进行注解。实例数据属性可以(可选地)通过 Final
注解为 final。
"class" ClassIdentifier "(torch.nn.Module):"
InstanceAttrIdentifier ":" ["Final("] TSType [")"]
...
其中
InstanceAttrIdentifier
是实例属性的名称。Final
表示该属性不能在__init__
之外重新赋值或在子类中被覆盖。
示例
import torch
class MyModule(torch.nn.Module):
offset_: int
def __init__(self, offset):
self.offset_ = offset
...
类型注解 API¶
torch.jit.annotate(T, expr)
¶
该 API 将类型 T
注解到表达式 expr
上。当表达式的默认类型不是程序员预期的类型时,经常使用此方法。例如,空列表(字典)的默认类型是 List[TensorType]
(Dict[TensorType, TensorType]
),但有时可能用于初始化其他类型的列表。另一个常见用例是注解 tensor.tolist()
的返回类型。然而请注意,它不能用于在 __init__ 中注解模块属性的类型;在这种情况下应使用 torch.jit.Attribute
。
示例
在此示例中,通过 torch.jit.annotate
将 []
声明为整数列表(而不是假定 []
的默认类型为 List[TensorType]
)
import torch
from typing import List
def g(l: List[int], val: int):
l.append(val)
return l
def f(val: int):
l = g(torch.jit.annotate(List[int], []), val)
return l
m = torch.jit.script(f)
print("Eager:", f(3))
print("TorchScript:", m(3))
有关更多信息,请参阅 torch.jit.annotate()
。
类型注解附录¶
TorchScript 类型系统定义¶
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSMetaType ::= "Any"
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"|
"torch.dtype" | "torch.nn.ModuleList" |
"torch.nn.ModuleDict" | ...
TSTensor ::= "torch.tensor" and subclasses
不支持的 Typing 构造¶
TorchScript 不支持 Python3 typing 模块的所有特性和类型。本文档中未明确说明的 typing 模块功能均不受支持。下表总结了 TorchScript 中不受支持或受限支持的 typing
构造。
项目 |
描述 |
|
开发中 |
|
不支持 |
|
不支持 |
|
不支持 |
|
不支持 |
|
支持模块属性、类属性和注解,但不支持函数。 |
|
不支持 |
|
开发中 |
类型别名 |
不支持 |
名义类型 |
开发中 |
结构类型 |
不支持 |
NewType |
不支持 |
泛型 |
不支持 |
表达式¶
以下部分描述了 TorchScript 中支持的表达式语法。它是参照 Python 语言参考手册的表达式章节建模的。
算术转换¶
TorchScript 中执行多项隐式类型转换
如果一个
Tensor
的数据类型为float
或int
,并且其大小为 0,require_grad
未设置为True
,且不需要窄化,则可以将其隐式转换为FloatType
或IntType
的实例。StringType
的实例可以隐式转换为DeviceType
。上述两个要点的隐式转换规则可以应用于
TupleType
的实例,以生成具有适当包含类型的ListType
实例。
显式转换可以使用接受原始数据类型作为参数的内置函数 float
、int
、bool
和 str
调用,如果用户自定义类型实现了 __bool__
、__str__
等方法,也可以接受这些用户自定义类型。
原子¶
原子是表达式中最基本的元素。
atom ::= identifier | literal | enclosure
enclosure ::= parenth_form | list_display | dict_display
标识符¶
规定 TorchScript 中合法标识符的规则与其 Python 对应规则相同。
字面量¶
literal ::= stringliteral | integer | floatnumber
字面量的求值会产生一个具有特定值(必要时对浮点数应用近似值)的相应类型的对象。字面量是不可变的,对相同字面量进行多次求值可能会获得相同的对象或具有相同值的不同对象。stringliteral、integer 和 floatnumber 的定义方式与其 Python 对应部分相同。
带括号的形式¶
parenth_form ::= '(' [expression_list] ')'
带括号的表达式列表的求值结果与表达式列表的求值结果相同。如果列表中至少包含一个逗号,则其结果为一个 Tuple
;否则,结果为表达式列表中的单个表达式。一对空括号的结果为一个空 Tuple
对象(Tuple[]
)。
列表和字典显示¶
list_comprehension ::= expression comp_for
comp_for ::= 'for' target_list 'in' or_expr
list_display ::= '[' [expression_list | list_comprehension] ']'
dict_display ::= '{' [key_datum_list | dict_comprehension] '}'
key_datum_list ::= key_datum (',' key_datum)*
key_datum ::= expression ':' expression
dict_comprehension ::= key_datum comp_for
列表和字典可以通过显式列出容器内容或通过提供一组循环指令(即,推导式)来计算内容来构造。推导式在语义上等同于使用 for 循环并追加到一个正在构建的列表。推导式隐式创建自己的作用域,以确保目标列表的项不会泄漏到外部作用域。在显式列出容器项的情况下,表达式列表中的表达式从左到右求值。如果在包含 key_datum_list
的 dict_display
中某个键被重复,则结果字典使用列表中使用该重复键的最右边数据项的值。
主项¶
primary ::= atom | attributeref | subscription | slicing | call
下标¶
subscription ::= primary '[' expression_list ']'
The primary
必须求值为一个支持下标的对象。
如果 primary 是一个
List
、Tuple
或str
,则表达式列表必须求值为一个整数或切片。如果 primary 是一个
Dict
,则表达式列表必须求值为与Dict
的键类型相同的对象。如果 primary 是一个
ModuleList
,则表达式列表必须是一个integer
字面量。如果 primary 是一个
ModuleDict
,则表达式必须是一个stringliteral
。
切片¶
切片选择 str
、Tuple
、List
或 Tensor
中的一系列项。切片可以用作赋值或 del
语句中的表达式或目标。
slicing ::= primary '[' slice_list ']'
slice_list ::= slice_item (',' slice_item)* [',']
slice_item ::= expression | proper_slice
proper_slice ::= [expression] ':' [expression] [':' [expression] ]
切片列表中包含多个切片项的切片只能用于求值为 Tensor
类型对象的 primary。
调用¶
call ::= primary '(' argument_list ')'
argument_list ::= args [',' kwargs] | kwargs
args ::= [arg (',' arg)*]
kwargs ::= [kwarg (',' kwarg)*]
kwarg ::= arg '=' expression
arg ::= identifier
The primary
必须脱糖或求值为一个可调用对象。所有参数表达式都在尝试调用之前进行求值。
幂运算符¶
power ::= primary ['**' u_expr]
幂运算符的语义与内置的 pow 函数(不支持)相同;它计算左操作数的右操作数次幂。它比左侧的一元运算符绑定更紧密,但比右侧的一元运算符绑定宽松;即 -2 ** -3 == -(2 ** (-3))
。左、右操作数可以是 int
、float
或 Tensor
。在标量-张量/张量-标量指数运算的情况下,标量会被广播,而张量-张量指数运算是按元素进行的,没有任何广播。
一元和算术位运算¶
u_expr ::= power | '-' power | '~' power
一元运算符 -
产生其参数的负数。一元运算符 ~
产生其参数的按位取反。 -
可用于 int
、float
以及 int
和 float
类型的 Tensor
。 ~
只能用于 int
和 int
类型的 Tensor
。
二元算术运算¶
m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr
a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr
二元算术运算符可对 Tensor
、int
和 float
进行运算。对于张量-张量运算,两个参数必须具有相同的形状。对于标量-张量或张量-标量运算,标量通常会被广播到张量的大小。除法运算只能接受标量作为右侧参数,且不支持广播。@
运算符用于矩阵乘法,并且只对 Tensor
参数进行运算。乘法运算符(*
)可用于列表和整数,以获得将原始列表重复一定次数的结果。
移位运算¶
shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr
这些运算符接受两个 int
参数、两个 Tensor
参数,或者一个 Tensor
参数和一个 int
或 float
参数。在所有情况下,右移 n
位定义为向下取整除以 pow(2, n)
,左移 n
位定义为乘以 pow(2, n)
。当两个参数都是 Tensor
时,它们必须具有相同的形状。当一个是标量而另一个是 Tensor
时,标量在逻辑上会被广播以匹配 Tensor
的大小。
二元位运算¶
and_expr ::= shift_expr | and_expr '&' shift_expr
xor_expr ::= and_expr | xor_expr '^' and_expr
or_expr ::= xor_expr | or_expr '|' xor_expr
运算符 &
计算其参数的按位与,^
计算按位异或,|
计算按位或。两个操作数必须是 int
或 Tensor
,或者左操作数是 Tensor
且右操作数是 int
。当两个操作数都是 Tensor
时,它们必须具有相同的形状。当右操作数是 int
且左操作数是 Tensor
时,右操作数在逻辑上会被广播以匹配 Tensor
的形状。
比较¶
comparison ::= or_expr (comp_operator or_expr)*
comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'
比较运算产生一个布尔值(True
或 False
),或者如果其中一个操作数是 Tensor
,则产生一个布尔 Tensor
。比较可以任意链式连接,只要它们产生的布尔 Tensors
不包含多个元素。a op1 b op2 c ...
等价于 a op1 b and b op2 c and ...
。
值比较¶
运算符 <
、>
、==
、>=
、<=
和 !=
比较两个对象的值。这两个对象通常需要是相同的类型,除非它们之间存在隐式类型转换。如果用户自定义类型定义了富比较方法(例如 __lt__
),则可以进行比较。内置类型的比较方式与 Python 相同
数字进行数学比较。
字符串进行字典序比较。
list
、tuple
和dict
只能与同类型的其他list
、tuple
和dict
进行比较,并使用对应元素的比较运算符进行比较。
成员资格测试运算¶
运算符 in
和 not in
测试成员资格。x in s
的求值结果为 True
(如果 x
是 s
的成员)或 False
(否则)。x not in s
等价于 not x in s
。此运算符支持 list
、dict
和 tuple
,如果用户自定义类型实现了 __contains__
方法,也可以使用此运算符。
身份比较¶
对于除 int
、double
、bool
和 torch.device
之外的所有类型,运算符 is
和 is not
测试对象的身份;当且仅当 x
和 y
是同一个对象时,x is y
为 True
。对于所有其他类型,is
等效于使用 ==
进行比较。x is not y
产生 x is y
的反结果。
布尔运算¶
or_test ::= and_test | or_test 'or' and_test
and_test ::= not_test | and_test 'and' not_test
not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test
用户自定义对象可以通过实现 __bool__
方法来自定义其转换为 bool
的方式。如果操作数为 false,则运算符 not
产生 True
,否则产生 False
。表达式 x
and y
首先求值 x
;如果其结果为 False
,则返回其值(False
);否则,求值 y
并返回其值(False
或 True
)。表达式 x
or y
首先求值 x
;如果其结果为 True
,则返回其值(True
);否则,求值 y
并返回其值(False
或 True
)。
条件表达式¶
conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression]
expression ::= conditional_expression
表达式 x if c else y
首先求值条件 c
而不是 x。如果 c
为 True
,则求值 x
并返回其值;否则,求值 y
并返回其值。与 if 语句一样,x
和 y
必须求值为相同类型的值。
表达式列表¶
expression_list ::= expression (',' expression)* [',']
starred_item ::= '*' primary
带星号的项只能出现在赋值语句的左侧,例如 a, *b, c = ...
。
简单语句¶
以下部分描述了 TorchScript 中支持的简单语句的语法。它是参照 Python 语言参考手册的简单语句章节建模的。
表达式语句¶
expression_stmt ::= starred_expression
starred_expression ::= expression | (starred_item ",")* [starred_item]
starred_item ::= assignment_expression | "*" or_expr
赋值语句¶
assignment_stmt ::= (target_list "=")+ (starred_expression)
target_list ::= target ("," target)* [","]
target ::= identifier
| "(" [target_list] ")"
| "[" [target_list] "]"
| attributeref
| subscription
| slicing
| "*" target
增强赋值语句¶
augmented_assignment_stmt ::= augtarget augop (expression_list)
augtarget ::= identifier | attributeref | subscription
augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" |
"**="| ">>=" | "<<=" | "&=" | "^=" | "|="
注解赋值语句¶
annotated_assignment_stmt ::= augtarget ":" expression
["=" (starred_expression)]
raise
语句¶
raise_stmt ::= "raise" [expression ["from" expression]]
TorchScript 中的 Raise 语句不支持 try\except\finally
。
assert
语句¶
assert_stmt ::= "assert" expression ["," expression]
TorchScript 中的 Assert 语句不支持 try\except\finally
。
return
语句¶
return_stmt ::= "return" [expression_list]
TorchScript 中的 Return 语句不支持 try\except\finally
。
del
语句¶
del_stmt ::= "del" target_list
pass
语句¶
pass_stmt ::= "pass"
print
语句¶
print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")"
break
语句¶
break_stmt ::= "break"
continue
语句:¶
continue_stmt ::= "continue"
复合语句¶
以下部分描述了 TorchScript 中支持的复合语句的语法。本节还强调了 Torchscript 与常规 Python 语句的不同之处。它是参照 Python 语言参考手册的复合语句章节建模的。
if
语句¶
Torchscript 支持基本的 if/else
和三元 if/else
。
基本的 if/else
语句¶
if_stmt ::= "if" assignment_expression ":" suite
("elif" assignment_expression ":" suite)
["else" ":" suite]
elif
语句可以重复任意多次,但必须在 else
语句之前。
三元 if/else
语句¶
if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list]
示例 1
一维 tensor
会被提升为 bool
import torch
@torch.jit.script
def fn(x: torch.Tensor):
if x: # The tensor gets promoted to bool
return True
return False
print(fn(torch.rand(1)))
上面的示例产生以下输出
True
示例 2
多维 tensor
不会被提升为 bool
import torch
# Multi dimensional Tensors error out.
@torch.jit.script
def fn():
if torch.rand(2):
print("Tensor is available")
if torch.rand(4,5,6):
print("Tensor is available")
print(fn())
运行上述代码会导致以下 RuntimeError
。
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
@torch.jit.script
def fn():
if torch.rand(2):
~~~~~~~~~~~~ <--- HERE
print("Tensor is available")
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
如果条件变量被注解为 final
,则根据条件变量的求值结果,求值 true 分支或 false 分支。
示例 3
在此示例中,由于 a
被注解为 final
并设置为 True
,因此只求值 True 分支。
import torch
a : torch.jit.final[Bool] = True
if a:
return torch.empty(2,3)
else:
return []
The while
语句¶
while_stmt ::= "while" assignment_expression ":" suite
Torchscript 中不支持 while…else 语句。它会导致 RuntimeError
。
The for-in
语句¶
for_stmt ::= "for" target_list "in" expression_list ":" suite
["else" ":" suite]
Torchscript 中不支持 for...else
语句。它会导致 RuntimeError
。
示例 1
对元组的 For 循环:这些循环会展开(unroll),为元组的每个成员生成一个循环体。循环体必须对每个成员进行正确的类型检查。
import torch
from typing import Tuple
@torch.jit.script
def fn():
tup = (3, torch.ones(4))
for x in tup:
print(x)
fn()
上面的示例产生以下输出
3
1
1
1
1
[ CPUFloatType{4} ]
示例 2
对列表的 For 循环:对 nn.ModuleList
的 For 循环会在编译时展开循环体,对模块列表的每个成员进行处理。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
model = torch.jit.script(MyModule())
The with
语句¶
The with
语句用于将一个代码块的执行包装在由上下文管理器定义的方法中。
with_stmt ::= "with" with_item ("," with_item) ":" suite
with_item ::= expression ["as" target]
如果
with
语句中包含目标,上下文管理器的__enter__()
方法的返回值将赋值给它。与 Python 不同的是,如果因异常退出代码块(suite),其类型、值和追溯信息不会作为参数传递给__exit__()
。会提供三个None
参数。try
、except
和finally
语句在with
块内不受支持。在
with
块内引发的异常无法被抑制。
tuple
语句¶
tuple_stmt ::= tuple([iterables])
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。你不能使用这个内置函数将 List 转换为 Tuple。
将所有输出解包到 tuple 中
abc = func() # Function that returns a tuple
a,b = func()
getattr
语句¶
getattr_stmt ::= getattr(object, name[, default])
属性名必须是字符串字面值。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
hasattr
语句¶
hasattr_stmt ::= hasattr(object, name)
属性名必须是字符串字面值。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
zip
语句¶
zip_stmt ::= zip(iterable1, iterable2)
参数必须是可迭代对象。
支持两个外部容器类型相同但长度不同的可迭代对象。
示例 1
两个可迭代对象必须是相同的容器类型
a = [1, 2] # List
b = [2, 3, 4] # List
zip(a, b) # works
示例 2
此示例失败是因为可迭代对象是不同的容器类型
a = (1, 2) # Tuple
b = [2, 3, 4] # List
zip(a, b) # Runtime error
运行上述代码会导致以下 RuntimeError
。
RuntimeError: Can not iterate over a module list or
tuple with a value that does not have a statically determinable length.
示例 3
支持两个容器类型相同但数据类型不同的可迭代对象
a = [1.3, 2.4]
b = [2, 3, 4]
zip(a, b) # Works
TorchScript 中的可迭代类型包括 Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
enumerate
语句¶
enumerate_stmt ::= enumerate([iterable])
参数必须是可迭代对象。
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。
Python 值¶
解析规则¶
给定一个 Python 值时,TorchScript 尝试通过以下五种不同的方式解析它
- 可编译的 Python 实现
当 Python 值由一个可以被 TorchScript 编译的 Python 实现支持时,TorchScript 会编译并使用底层的 Python 实现。
示例:
torch.jit.Attribute
- Op Python 封装器
当 Python 值是原生 PyTorch 操作 (op) 的封装器时,TorchScript 会生成对应的操作符。
示例:
torch.jit._logging.add_stat_value
- Python 对象身份匹配
对于 TorchScript 支持的有限集合的
torch.*
API 调用(以 Python 值的形式),TorchScript 尝试将 Python 值与集合中的每个项进行匹配。匹配成功后,TorchScript 会生成对应的
SugaredValue
实例,该实例包含这些值的降低(lowering)逻辑。示例:
torch.jit.isinstance()
- 名称匹配
对于 Python 内置函数和常量,TorchScript 通过名称识别它们,并创建一个实现其功能的相应
SugaredValue
实例。示例:
all()
- 值快照
对于来自未知模块的 Python 值,TorchScript 尝试捕获该值的快照,并将其转换为正在编译的函数或方法图中的常量。
示例:
math.pi
支持的 Python 内置函数¶
内置函数 |
支持级别 |
注释 |
---|---|---|
|
部分 |
仅支持 |
|
完全 |
|
|
完全 |
|
|
无 |
|
|
部分 |
仅支持 |
|
部分 |
仅支持 |
|
无 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
部分 |
仅支持 ASCII 字符集。 |
|
完全 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
完全 |
|
|
无 |
|
|
完全 |
|
|
完全 |
|
|
无 |
|
|
无 |
|
|
无 |
|
|
部分 |
不遵循 |
|
部分 |
不支持手动指定索引。| 不支持格式类型修饰符。 |
|
无 |
|
|
部分 |
属性名必须是字符串字面值。 |
|
无 |
|
|
部分 |
属性名必须是字符串字面值。 |
|
完全 |
|
|
部分 |
仅支持 |
|
完全 |
仅支持 |
|
无 |
|
|
部分 |
不支持 |
|
完全 |
在检查 |
|
无 |
|
|
无 |
|
|
完全 |
|
|
完全 |
|
|
部分 |
仅支持 ASCII 字符集。 |
|
完全 |
|
|
部分 |
不支持 |
|
无 |
|
|
完全 |
|
|
无 |
|
|
无 |
|
|
部分 |
不支持 |
|
无 |
|
|
无 |
|
|
完全 |
|
|
部分 |
不支持 |
|
完全 |
|
|
部分 |
不支持 |
|
完全 |
|
|
部分 |
它只能用于 |
|
无 |
|
|
无 |
|
|
完全 |
|
|
无 |
torch.* API¶
远程过程调用¶
TorchScript 支持 RPC API 的一个子集,该子集支持在指定的远程工作节点而非本地运行函数。
具体来说,以下 API 得到完全支持
torch.distributed.rpc.rpc_sync()
rpc_sync()
发起一个阻塞的 RPC 调用,以便在远程工作节点上运行一个函数。RPC 消息与 Python 代码的执行并行地发送和接收。其用法和示例的更多详细信息可在
rpc_sync()
中找到。
torch.distributed.rpc.rpc_async()
rpc_async()
发起一个非阻塞的 RPC 调用,以便在远程工作节点上运行一个函数。RPC 消息与 Python 代码的执行并行地发送和接收。其用法和示例的更多详细信息可在
rpc_async()
中找到。
torch.distributed.rpc.remote()
remote.()
在工作节点上执行远程调用,并获得一个远程引用RRef
作为返回值。其用法和示例的更多详细信息可在
remote()
中找到。
异步执行¶
TorchScript 使你能够创建异步计算任务,以更好地利用计算资源。这是通过支持仅在 TorchScript 内部使用的一系列 API 来实现的
类型标注¶
TorchScript 是静态类型的。它提供并支持一组实用工具,以帮助标注变量和属性。
torch.jit.annotate()
在 Python 3 风格的类型提示不太奏效的情况下,为 TorchScript 提供类型提示。
一个常见的例子是为像
[]
这样的表达式标注类型。默认情况下,[]
被视为List[torch.Tensor]
。当需要不同的类型时,你可以使用此代码来提示 TorchScript:torch.jit.annotate(List[int], [])
。更多详细信息可在
annotate()
中找到
torch.jit.Attribute
常见的用例包括为
torch.nn.Module
属性提供类型提示。由于它们的__init__
方法不被 TorchScript 解析,在模块的__init__
方法中应使用torch.jit.Attribute
而非torch.jit.annotate
。更多详细信息可在
Attribute()
中找到
torch.jit.Final
Python 的
typing.Final
的别名。torch.jit.Final
仅为向后兼容而保留。
元编程¶
TorchScript 提供了一组实用工具来促进元编程。
torch.jit.is_scripting()
返回一个布尔值,指示当前程序是否由
torch.jit.script
编译。当在
assert
或if
语句中使用时,torch.jit.is_scripting()
求值为False
的作用域或分支不会被编译。它的值可以在编译时静态求值,因此常用于
if
语句中,以阻止 TorchScript 编译其中一个分支。更多详细信息和示例可在
is_scripting()
中找到
torch.jit.is_tracing()
返回一个布尔值,指示当前程序是否由
torch.jit.trace
/torch.jit.trace_module
跟踪。更多详细信息可在
is_tracing()
中找到
@torch.jit.ignore
此装饰器指示编译器,一个函数或方法应被忽略,并保留为 Python 函数。
这使得你可以在模型中保留尚未与 TorchScript 兼容的代码。
如果从 TorchScript 调用一个由
@torch.jit.ignore
装饰的函数,被忽略的函数会将调用分派给 Python 解释器。包含被忽略函数的模型无法导出。
更多详细信息和示例可在
ignore()
中找到
@torch.jit.unused
此装饰器指示编译器,一个函数或方法应被忽略,并替换为引发异常。
这使得你可以在模型中保留尚未与 TorchScript 兼容的代码,并且仍然可以导出模型。
如果从 TorchScript 调用一个由
@torch.jit.unused
装饰的函数,将会引发运行时错误。更多详细信息和示例可在
unused()
中找到
类型细化¶
torch.jit.isinstance()
返回一个布尔值,指示变量是否为指定类型。
其用法和示例的更多详细信息可在
isinstance()
中找到。