TorchScript 语言参考¶
本参考手册描述了 TorchScript 语言的语法和核心语义。TorchScript 是 Python 语言的静态类型子集。本文档解释了 TorchScript 中支持的 Python 功能,以及该语言与常规 Python 的不同之处。本参考手册中未提及的任何 Python 功能都不是 TorchScript 的一部分。TorchScript 专门关注表示 PyTorch 中神经网络模型所需的 Python 功能。
术语¶
本文档使用以下术语
模式 |
注释 |
---|---|
|
表示给定的符号定义为。 |
|
表示作为语法一部分的真实关键字和分隔符。 |
|
表示 A 或 B。 |
|
表示分组。 |
|
表示可选。 |
|
表示正则表达式,其中术语 A 至少重复一次。 |
|
表示正则表达式,其中术语 A 重复零次或多次。 |
类型系统¶
TorchScript 是 Python 的静态类型子集。TorchScript 与完整 Python 语言之间最大的区别在于,TorchScript 仅支持一小部分表达神经网络模型所需的类型。
TorchScript 类型¶
TorchScript 类型系统由下面定义的 TSType
和 TSModuleType
组成。
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSType
表示大多数 TorchScript 类型,这些类型是可组合的,并且可以在 TorchScript 类型注解中使用。TSType
指的是以下任何类型
元类型,例如,
Any
原始类型,例如,
int
、float
和str
结构类型,例如,
Optional[int]
或List[MyClass]
名义类型(Python 类),例如,
MyClass
(用户定义)、torch.tensor
(内置)
TSModuleType
表示 torch.nn.Module
及其子类。它与 TSType
的处理方式不同,因为它的类型模式部分是从对象实例推断出来的,部分是从类定义推断出来的。因此,TSModuleType
的实例可能不遵循相同的静态类型模式。TSModuleType
不能用作 TorchScript 类型注解,也不能与 TSType
组合以实现类型安全考虑。
元类型¶
元类型非常抽象,以至于它们更像是类型约束而不是具体类型。目前,TorchScript 定义了一种元类型 Any
,它表示任何 TorchScript 类型。
Any
类型¶
Any
类型表示任何 TorchScript 类型。Any
不指定类型约束,因此不对 Any
进行类型检查。因此,它可以绑定到任何 Python 或 TorchScript 数据类型(例如,int
、TorchScript tuple
或未编写脚本的任意 Python 类)。
TSMetaType ::= "Any"
其中
Any
是来自 typing 模块的 Python 类名称。因此,要使用Any
类型,您必须从typing
导入它(例如,from typing import Any
)。由于
Any
可以表示任何 TorchScript 类型,因此允许对Any
类型的数值进行操作的运算符集是有限的。
为 Any
类型支持的运算符¶
赋值给
Any
类型的数据。绑定到
Any
类型的参数或返回值。x is
、x is not
,其中x
的类型为Any
。isinstance(x, Type)
,其中x
的类型为Any
。Any
类型的数据是可打印的。如果数据是相同类型
T
的值列表,并且T
支持比较运算符,则List[Any]
类型的数据可能是可排序的。
与 Python 相比
Any
是 TorchScript 类型系统中约束最少的类型。从这个意义上讲,它与 Python 中的 Object
类非常相似。但是,Any
仅支持 Object
支持的运算符和方法的子集。
设计注释¶
当我们编写 PyTorch 模块的脚本时,我们可能会遇到未参与脚本执行的数据。然而,它必须由类型模式来描述。为未使用的(在脚本上下文中)数据描述静态类型不仅很麻烦,而且还可能导致不必要的脚本编写失败。Any
的引入是为了描述那些精确静态类型对于编译不是必需的数据类型。
示例 1
此示例说明了如何使用 Any
来允许元组参数的第二个元素为任何类型。这是可能的,因为 x[1]
未参与任何需要知道其精确类型的计算。
import torch
from typing import Tuple
from typing import Any
@torch.jit.export
def inc_first_element(x: Tuple[int, Any]):
return (x[0]+1, x[1])
m = torch.jit.script(inc_first_element)
print(m((1,2.0)))
print(m((1,(100,200))))
上面的示例产生以下输出
(2, 2.0)
(2, (100, 200))
元组的第二个元素是 Any
类型,因此可以绑定到多种类型。例如,(1, 2.0)
将浮点类型绑定到 Any
,如 Tuple[int, Any]
中所示,而 (1, (100, 200))
在第二次调用中将元组绑定到 Any
。
示例 2
此示例说明了如何使用 isinstance
来动态检查注解为 Any
类型的数据类型
import torch
from typing import Any
def f(a:Any):
print(a)
return (isinstance(a, torch.Tensor))
ones = torch.ones([2])
m = torch.jit.script(f)
print(m(ones))
上面的示例产生以下输出
1
1
[ CPUFloatType{2} ]
True
原始类型¶
原始 TorchScript 类型是表示单值类型并带有单个预定义类型名称的类型。
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
结构类型¶
结构类型是在结构上定义的类型,没有用户定义的名称(与名义类型不同),例如 Future[int]
。结构类型可以与任何 TSType
组合。
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict |
TSOptional | TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
其中
Tuple
、List
、Optional
、Union
、Future
、Dict
表示在模块typing
中定义的 Python 类型类名称。要使用这些类型名称,您必须从typing
导入它们(例如,from typing import Tuple
)。namedtuple
表示 Python 类collections.namedtuple
或typing.NamedTuple
。Future
和RRef
表示 Python 类torch.futures
和torch.distributed.rpc
。Await
表示 Python 类torch._awaits._Await
与 Python 相比
除了可以与 TorchScript 类型组合之外,这些 TorchScript 结构类型通常还支持其 Python 对应项的常用运算符和方法的子集。
示例 1
此示例使用 typing.NamedTuple
语法来定义元组
import torch
from typing import NamedTuple
from typing import Tuple
class MyTuple(NamedTuple):
first: int
second: int
def inc(x: MyTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
t = MyTuple(first=1, second=2)
scripted_inc = torch.jit.script(inc)
print("TorchScript:", scripted_inc(t))
上面的示例产生以下输出
TorchScript: (2, 3)
示例 2
此示例使用 collections.namedtuple
语法来定义元组
import torch
from typing import NamedTuple
from typing import Tuple
from collections import namedtuple
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second'])
def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]:
return (x.first+1, x.second+1)
m = torch.jit.script(inc)
print(inc(_UnannotatedNamedTuple(1,2)))
上面的示例产生以下输出
(2, 3)
示例 3
此示例说明了注解结构类型的常见错误,即未从 typing
模块导入复合类型类
import torch
# ERROR: Tuple not recognized because not imported from typing
@torch.jit.export
def inc(x: Tuple[int, int]):
return (x[0]+1, x[1]+1)
m = torch.jit.script(inc)
print(m((1,2)))
运行上面的代码会产生以下脚本编写错误
File "test-tuple.py", line 5, in <module>
def inc(x: Tuple[int, int]):
NameError: name 'Tuple' is not defined
补救措施是在代码的开头添加行 from typing import Tuple
。
名义类型¶
名义 TorchScript 类型是 Python 类。这些类型之所以称为名义类型,是因为它们是用自定义名称声明的,并且使用类名称进行比较。名义类进一步分为以下几类
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
其中,TSCustomClass
和 TSEnum
必须可编译为 TorchScript 中间表示 (IR)。这是由类型检查器强制执行的。
内置类¶
内置名义类型是语义内置于 TorchScript 系统中的 Python 类(例如,张量类型)。TorchScript 定义了这些内置名义类型的语义,并且通常仅支持其 Python 类定义的方法或属性的子集。
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
"torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
"torch.nn.parameter.Parameter" | and subclasses of torch.Tensor
关于 torch.nn.ModuleList 和 torch.nn.ModuleDict 的特别说明¶
虽然 torch.nn.ModuleList
和 torch.nn.ModuleDict
在 Python 中被定义为列表和字典,但它们在 TorchScript 中的行为更像元组
在 TorchScript 中,
torch.nn.ModuleList
或torch.nn.ModuleDict
的实例是不可变的。迭代
torch.nn.ModuleList
或torch.nn.ModuleDict
的代码被完全展开,以便torch.nn.ModuleList
的元素或torch.nn.ModuleDict
的键可以是torch.nn.Module
的不同子类。
示例
以下示例重点介绍了几个内置 Torchscript 类 (torch.*
) 的用法
import torch
@torch.jit.script
class A:
def __init__(self):
self.x = torch.rand(3)
def f(self, y: torch.device):
return self.x.to(device=y)
def g():
a = A()
return a.f(torch.device("cpu"))
script_g = torch.jit.script(g)
print(script_g.graph)
自定义类¶
与内置类不同,自定义类的语义是用户定义的,并且整个类定义必须可编译为 TorchScript IR 并受 TorchScript 类型检查规则的约束。
TSClassDef ::= [ "@torch.jit.script" ]
"class" ClassName [ "(object)" ] ":"
MethodDefinition |
[ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
MethodDefinition
其中
类必须是新式类。Python 3 仅支持新式类。在 Python 2.x 中,通过从对象子类化来指定新式类。
实例数据属性是静态类型的,并且必须通过在
__init__()
方法内部进行赋值来声明实例属性。不支持方法重载(即,您不能有多个具有相同方法名称的方法)。
MethodDefinition
必须可编译为 TorchScript IR 并遵守 TorchScript 的类型检查规则,(即,所有方法都必须是有效的 TorchScript 函数,并且类属性定义必须是有效的 TorchScript 语句)。torch.jit.ignore
和torch.jit.unused
可用于忽略未完全可 TorchScript 化或应被编译器忽略的方法或函数。
与 Python 相比
与 Python 对应项相比,TorchScript 自定义类受到很大限制。Torchscript 自定义类
不支持类属性。
不支持子类化,除非子类化接口类型或对象。
不支持方法重载。
必须在
__init__()
中初始化其所有实例属性;这是因为 TorchScript 通过推断__init__()
中的属性类型来构造类的静态模式。必须仅包含满足 TorchScript 类型检查规则并且可编译为 TorchScript IR 的方法。
示例 1
如果 Python 类使用 @torch.jit.script
进行注解,则可以在 TorchScript 中使用它们,类似于声明 TorchScript 函数的方式
@torch.jit.script
class MyClass:
def __init__(self, x: int):
self.x = x
def inc(self, val: int):
self.x += val
示例 2
TorchScript 自定义类类型必须通过在 __init__()
中赋值来“声明”其所有实例属性。如果实例属性未在 __init__()
中定义,但在类的其他方法中访问,则该类无法编译为 TorchScript 类,如下例所示
import torch
@torch.jit.script
class foo:
def __init__(self):
self.y = 1
# ERROR: self.x is not defined in __init__
def assign_x(self):
self.x = torch.rand(2, 3)
该类将无法编译并发出以下错误
RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
self.x = torch.rand(2, 3)
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
示例 3
在此示例中,TorchScript 自定义类定义了一个类变量名称,这是不允许的
import torch
@torch.jit.script
class MyClass(object):
name = "MyClass"
def __init__(self, x: int):
self.x = x
def fn(a: MyClass):
return a.name
这会导致以下编译时错误
RuntimeError:
'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
File "test-class2.py", line 10
def fn(a: MyClass):
return a.name
~~~~~~ <--- HERE
枚举类型¶
与自定义类一样,枚举类型的语义是用户定义的,并且整个类定义必须可编译为 TorchScript IR 并遵守 TorchScript 类型检查规则。
TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":"
( MemberIdentifier "=" Value )+
( MethodDefinition )*
其中
值必须是
int
、float
或str
类型的 TorchScript 字面量,并且必须是相同的 TorchScript 类型。TSEnumType
是 TorchScript 枚举类型的名称。与 Python 枚举类似,TorchScript 允许受限的Enum
子类化,也就是说,仅当枚举类未定义任何成员时才允许子类化。
与 Python 相比
TorchScript 仅支持
enum.Enum
。它不支持其他变体,例如enum.IntEnum
、enum.Flag
、enum.IntFlag
和enum.auto
。TorchScript 枚举成员的值必须是相同的类型,并且只能是
int
、float
或str
类型,而 Python 枚举成员可以是任何类型。包含方法的枚举在 TorchScript 中被忽略。
示例 1
以下示例将类 Color
定义为 Enum
类型
import torch
from enum import Enum
class Color(Enum):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
print("TorchScript: ", m(Color.RED, Color.GREEN))
示例 2
以下示例显示了受限枚举子类化的情况,其中 BaseColor
未定义任何成员,因此可以被 Color
子类化
import torch
from enum import Enum
class BaseColor(Enum):
def foo(self):
pass
class Color(BaseColor):
RED = 1
GREEN = 2
def enum_fn(x: Color, y: Color) -> bool:
if x == Color.RED:
return True
return x == y
m = torch.jit.script(enum_fn)
print("TorchScript: ", m(Color.RED, Color.GREEN))
print("Eager: ", enum_fn(Color.RED, Color.GREEN))
TorchScript 模块类¶
TSModuleType
是一种特殊的类类型,它是从 TorchScript 外部创建的对象实例推断出来的。TSModuleType
以对象实例的 Python 类命名。Python 类的 __init__()
方法不被视为 TorchScript 方法,因此它不必遵守 TorchScript 的类型检查规则。
模块实例类的类型模式直接从实例对象(在 TorchScript 范围之外创建)构建,而不是像自定义类那样从 __init__()
推断。同一实例类类型的两个对象可能遵循两种不同的类型模式。
从这个意义上讲,TSModuleType
并不是真正的静态类型。因此,出于类型安全考虑,TSModuleType
不能在 TorchScript 类型注解中使用,也不能与 TSType
组合。
模块实例类¶
TorchScript 模块类型表示用户定义的 PyTorch 模块实例的类型模式。在编写 PyTorch 模块的脚本时,模块对象始终在 TorchScript 外部创建(即,作为参数传递到 forward
)。Python 模块类被视为模块实例类,因此 Python 模块类的 __init__()
方法不受 TorchScript 的类型检查规则的约束。
TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":"
ClassBodyDefinition
其中
forward()
和使用@torch.jit.export
修饰的其他方法必须可编译为 TorchScript IR 并受 TorchScript 的类型检查规则的约束。
与自定义类不同,只有模块类型的 forward 方法和使用 @torch.jit.export
修饰的其他方法需要是可编译的。最值得注意的是,__init__()
不被视为 TorchScript 方法。因此,模块类型构造函数不能在 TorchScript 范围内调用。相反,TorchScript 模块对象始终在外部构造并传递到 torch.jit.script(ModuleObj)
。
示例 1
此示例说明了模块类型的几个功能
TestModule
实例在 TorchScript 作用域之外创建(即,在调用torch.jit.script
之前)。__init__()
不被视为 TorchScript 方法,因此,它不必进行注解,并且可以包含任意 Python 代码。 此外,不能在 TorchScript 代码中调用实例类的__init__()
方法。 因为TestModule
实例是在 Python 中实例化的,所以在本例中,TestModule(2.0)
和TestModule(2)
创建了两个实例,它们的数据属性类型不同。 对于TestModule(2.0)
,self.x
的类型为float
,而对于TestModule(2.0)
,self.y
的类型为int
。TorchScript 自动编译通过
@torch.jit.export
或forward()
方法注解调用的其他方法(例如,mul()
)。TorchScript 程序的入口点可以是模块类型的
forward()
,注解为torch.jit.script
的函数,或注解为torch.jit.export
的方法。
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, inc: int):
return self.x + inc
m = torch.jit.script(TestModule(1))
print(f"First instance: {m(3)}")
m = torch.jit.script(TestModule(torch.ones([5])))
print(f"Second instance: {m(3)}")
上面的示例产生以下输出
First instance: 4
Second instance: tensor([4., 4., 4., 4., 4.])
示例 2
以下示例显示了模块类型的不正确用法。 具体来说,此示例在 TorchScript 作用域内调用了 TestModule
的构造函数
import torch
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v
def forward(self, x: int):
return self.x + x
class MyModel:
def __init__(self, v: int):
self.val = v
@torch.jit.export
def doSomething(self, val: int) -> int:
# error: should not invoke the constructor of module type
myModel = TestModule(self.val)
return myModel(val)
# m = torch.jit.script(MyModel(2)) # Results in below RuntimeError
# RuntimeError: Could not get name of python class object
类型注解¶
由于 TorchScript 是静态类型化的,程序员需要在 TorchScript 代码的策略性位置注解类型,以便每个局部变量或实例数据属性都具有静态类型,并且每个函数和方法都具有静态类型化的签名。
何时注解类型¶
通常,只有在无法自动推断静态类型的地方才需要类型注解(例如,方法或函数的参数或有时是返回类型)。 局部变量和数据属性的类型通常从它们的赋值语句中自动推断出来。 有时,推断的类型可能过于严格,例如,通过赋值 x = None
推断 x
为 NoneType
,而 x
实际上用作 Optional
。 在这种情况下,可能需要类型注解来覆盖自动推断,例如,x: Optional[int] = None
。 请注意,即使可以自动推断局部变量或数据属性的类型,对其进行类型注解也始终是安全的。 注解的类型必须与 TorchScript 的类型检查一致。
当参数、局部变量或数据属性未进行类型注解且其类型无法自动推断时,TorchScript 假定其为默认类型 TensorType
、List[TensorType]
或 Dict[str, TensorType]
。
注解函数签名¶
由于参数可能无法从函数体(包括函数和方法)中自动推断出来,因此需要对它们进行类型注解。 否则,它们将假定默认类型 TensorType
。
TorchScript 支持两种方法和函数签名类型注解的风格
Python3 风格 直接在签名上注解类型。 这样,它允许单独的参数保持未注解状态(其类型将为默认类型
TensorType
),或者允许返回类型保持未注解状态(其类型将自动推断)。
Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
FuncOrMethodBody
ParamAnnot ::= Identifier [ ":" TSType ] ","
ReturnAnnot ::= "->" TSType
请注意,使用 Python3 风格时,类型 self
会自动推断,不应进行注解。
Mypy 风格 将类型注解为函数/方法声明正下方的注释。 在 Mypy 风格中,由于参数名称未出现在注解中,因此必须注解所有参数。
MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
ParamAnnot ::= TSType ","
ReturnAnnot ::= "->" TSType
示例 1
在本例中
a
未注解,并假定为默认类型TensorType
。b
注解为类型int
。返回类型未注解,并自动推断为类型
TensorType
(基于返回值的类型)。
import torch
def f(a, b: int):
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
示例 2
以下示例使用 Mypy 风格注解。 请注意,即使某些参数或返回值假定为默认类型,也必须对其进行注解。
import torch
def f(a, b):
# type: (torch.Tensor, int) → torch.Tensor
return a+b
m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))
注解变量和数据属性¶
通常,数据属性(包括类和实例数据属性)和局部变量的类型可以从赋值语句中自动推断出来。 然而,有时,如果变量或属性与不同类型的值关联(例如,None
或 TensorType
),则可能需要将其显式类型注解为更宽泛的类型,例如 Optional[int]
或 Any
。
局部变量¶
局部变量可以根据 Python3 typing 模块注解规则进行注解,即:
LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr
通常,可以自动推断局部变量的类型。 但是,在某些情况下,您可能需要为可能与不同具体类型关联的局部变量注解多类型。 典型的多类型包括 Optional[T]
和 Any
。
示例
import torch
def f(a, setVal: bool):
value: Optional[torch.Tensor] = None
if setVal:
value = a
return value
ones = torch.ones([6])
m = torch.jit.script(f)
print("TorchScript:", m(ones, True), m(ones, False))
实例数据属性¶
对于 ModuleType
类,实例数据属性可以根据 Python3 typing 模块注解规则进行注解。 实例数据属性可以(可选地)通过 Final
注解为 final。
"class" ClassIdentifier "(torch.nn.Module):"
InstanceAttrIdentifier ":" ["Final("] TSType [")"]
...
其中
InstanceAttrIdentifier
是实例属性的名称。Final
表示该属性不能在__init__
之外重新分配或在子类中重写。
示例
import torch
class MyModule(torch.nn.Module):
offset_: int
def __init__(self, offset):
self.offset_ = offset
...
类型注解 API¶
torch.jit.annotate(T, expr)
¶
此 API 将类型 T
注解到表达式 expr
。 当表达式的默认类型不是程序员预期的类型时,通常会使用此 API。 例如,空列表(字典)的默认类型为 List[TensorType]
(Dict[TensorType, TensorType]
),但有时它可能用于初始化某种其他类型的列表。 另一个常见的用例是注解 tensor.tolist()
的返回类型。 但是请注意,它不能用于注解 __init__ 中的模块属性; 应该改用 torch.jit.Attribute
。
示例
在本例中,[]
通过 torch.jit.annotate
声明为整数列表(而不是假定 []
为默认类型 List[TensorType]
)
import torch
from typing import List
def g(l: List[int], val: int):
l.append(val)
return l
def f(val: int):
l = g(torch.jit.annotate(List[int], []), val)
return l
m = torch.jit.script(f)
print("Eager:", f(3))
print("TorchScript:", m(3))
有关更多信息,请参见 torch.jit.annotate()
。
类型注解附录¶
TorchScript 类型系统定义¶
TSAllType ::= TSType | TSModuleType
TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSMetaType ::= "Any"
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
TSUnion | TSFuture | TSRRef | TSAwait
TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList ::= "List" "[" TSType "]"
TSOptional ::= "Optional" "[" TSType "]"
TSUnion ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture ::= "Future" "[" TSType "]"
TSRRef ::= "RRef" "[" TSType "]"
TSAwait ::= "Await" "[" TSType "]"
TSDict ::= "Dict" "[" KeyType "," TSType "]"
KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any"
TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"|
"torch.dtype" | "torch.nn.ModuleList" |
"torch.nn.ModuleDict" | ...
TSTensor ::= "torch.tensor" and subclasses
不支持的类型构造¶
TorchScript 不支持 Python3 typing 模块的所有特性和类型。 typing 模块中任何未在本文档中明确指定的功能均不受支持。 下表总结了 typing
构造,这些构造在 TorchScript 中不受支持或受到限制。
条目 |
描述 |
|
开发中 |
|
不支持 |
|
不支持 |
|
不支持 |
|
不支持 |
|
支持模块属性、类属性和注解,但不支持函数。 |
|
不支持 |
|
开发中 |
类型别名 |
不支持 |
标称类型 |
开发中 |
结构类型 |
不支持 |
NewType |
不支持 |
泛型 |
不支持 |
表达式¶
以下部分描述了 TorchScript 中支持的表达式的语法。 它以 Python 语言参考的表达式章节为蓝本。
算术转换¶
TorchScript 中执行了许多隐式类型转换
数据类型为
float
或int
的Tensor
可以隐式转换为FloatType
或IntType
的实例,前提是它的大小为 0,未将require_grad
设置为True
,并且不需要缩小。StringType
的实例可以隐式转换为DeviceType
。以上两点中的隐式转换规则可以应用于
TupleType
的实例,以生成具有适当包含类型的ListType
的实例。
可以使用内置函数 float
、int
、bool
和 str
调用显式转换,这些函数接受原始数据类型作为参数,如果用户定义的类型实现了 __bool__
、__str__
等,则也可以接受用户定义的类型。
原子¶
原子是表达式的最基本元素。
atom ::= identifier | literal | enclosure
enclosure ::= parenth_form | list_display | dict_display
标识符¶
在 TorchScript 中,决定什么是合法标识符的规则与其 Python 对应物相同。
字面量¶
literal ::= stringliteral | integer | floatnumber
字面量的求值会产生具有特定值的适当类型的对象(并根据需要对浮点数应用近似值)。 字面量是不可变的,对相同字面量的多次求值可能会获得相同的对象或具有相同值的不同对象。 stringliteral、integer 和 floatnumber 的定义方式与其 Python 对应物相同。
带括号的形式¶
parenth_form ::= '(' [expression_list] ')'
带括号的表达式列表产生表达式列表产生的任何内容。 如果列表至少包含一个逗号,则它产生一个 Tuple
; 否则,它产生表达式列表内的单个表达式。 空括号对产生一个空的 Tuple
对象 (Tuple[]
)。
列表和字典显示¶
list_comprehension ::= expression comp_for
comp_for ::= 'for' target_list 'in' or_expr
list_display ::= '[' [expression_list | list_comprehension] ']'
dict_display ::= '{' [key_datum_list | dict_comprehension] '}'
key_datum_list ::= key_datum (',' key_datum)*
key_datum ::= expression ':' expression
dict_comprehension ::= key_datum comp_for
列表和字典可以通过显式列出容器内容或通过提供有关如何通过一组循环指令(即推导式)计算它们的指令来构造。 推导式在语义上等效于使用 for 循环并附加到正在进行的列表中。 推导式隐式创建自己的作用域,以确保目标列表的项目不会泄漏到封闭作用域中。 如果在具有 key_datum_list
的 dict_display
中重复了键,则结果字典使用列表中使用重复键的最右侧数据中的值。
主要项¶
primary ::= atom | attributeref | subscription | slicing | call
下标¶
subscription ::= primary '[' expression_list ']'
primary
必须求值为支持下标的对象。
如果 primary 是
List
、Tuple
或str
,则表达式列表必须求值为整数或切片。如果 primary 是
Dict
,则表达式列表必须求值为与Dict
的键类型相同的类型的对象。如果 primary 是
ModuleList
,则表达式列表必须是integer
字面量。如果 primary 是
ModuleDict
,则表达式必须是stringliteral
。
切片¶
切片选择 str
、Tuple
、List
或 Tensor
中的一系列项目。 切片可以用作赋值或 del
语句中的表达式或目标。
slicing ::= primary '[' slice_list ']'
slice_list ::= slice_item (',' slice_item)* [',']
slice_item ::= expression | proper_slice
proper_slice ::= [expression] ':' [expression] [':' [expression] ]
切片列表中具有多个切片项的切片只能与求值为 Tensor
类型对象的 primary 一起使用。
调用¶
call ::= primary '(' argument_list ')'
argument_list ::= args [',' kwargs] | kwargs
args ::= [arg (',' arg)*]
kwargs ::= [kwarg (',' kwarg)*]
kwarg ::= arg '=' expression
arg ::= identifier
primary
必须反糖化或求值为可调用对象。 所有参数表达式都在尝试调用之前求值。
幂运算符¶
power ::= primary ['**' u_expr]
幂运算符具有与内置 pow 函数(不支持)相同的语义; 它计算其左参数的右参数次方。 它比左侧的一元运算符绑定得更紧密,但比右侧的一元运算符绑定得更松散; 即 -2 ** -3 == -(2 ** (-3))
。 左操作数和右操作数可以是 int
、float
或 Tensor
。 在标量-张量/张量-标量求幂运算的情况下,标量会广播,而张量-张量求幂运算是逐元素完成的,没有任何广播。
一元和算术按位运算¶
u_expr ::= power | '-' power | '~' power
一元 -
运算符产生其参数的取反。 一元 ~
运算符产生其参数的按位反转。 -
可以与 int
、float
和 int
和 float
的 Tensor
一起使用。 ~
只能与 int
和 int
的 Tensor
一起使用。
二元算术运算¶
m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr
a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr
二元算术运算符可以对 Tensor
、int
和 float
进行运算。 对于张量-张量运算,两个参数必须具有相同的形状。 对于标量-张量或张量-标量运算,标量通常广播到张量的大小。 除法运算只能接受标量作为其右手参数,并且不支持广播。 @
运算符用于矩阵乘法,并且仅对 Tensor
参数进行运算。 乘法运算符 (*
) 可以与列表和整数一起使用,以便获得将原始列表重复一定次数的结果。
移位运算¶
shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr
这些运算符接受两个 int
参数、两个 Tensor
参数,或者一个 Tensor
参数和一个 int
或 float
参数。 在所有情况下,右移 n
定义为除以 pow(2, n)
的向下取整除法,左移 n
定义为乘以 pow(2, n)
。 当两个参数都是 Tensor
时,它们必须具有相同的形状。 当一个是标量,另一个是 Tensor
时,标量在逻辑上会广播以匹配 Tensor
的大小。
二元按位运算¶
and_expr ::= shift_expr | and_expr '&' shift_expr
xor_expr ::= and_expr | xor_expr '^' and_expr
or_expr ::= xor_expr | or_expr '|' xor_expr
&
运算符计算其参数的按位与,^
计算按位异或,|
计算按位或。 两个操作数都必须是 int
或 Tensor
,或者左操作数必须是 Tensor
,右操作数必须是 int
。 当两个操作数都是 Tensor
时,它们必须具有相同的形状。 当右操作数是 int
,左操作数是 Tensor
时,右操作数在逻辑上会广播以匹配 Tensor
的形状。
比较¶
comparison ::= or_expr (comp_operator or_expr)*
comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'
比较产生一个布尔值 (True
或 False
),或者如果其中一个操作数是 Tensor
,则产生一个布尔 Tensor
。 只要比较不产生具有多个元素的布尔 Tensor
,就可以任意链接比较。 a op1 b op2 c ...
等效于 a op1 b and b op2 c and ...
。
值比较¶
运算符 <
、>
、==
、>=
、<=
和 !=
比较两个对象的值。 这两个对象通常需要是相同类型的,除非对象之间存在隐式类型转换。 如果用户定义的类型上定义了丰富的比较方法(例如,__lt__
),则可以比较用户定义的类型。 内置类型比较的工作方式类似于 Python
数字在数学上进行比较。
字符串按字典顺序比较。
lists
、tuples
和dicts
只能与相同类型的其他lists
、tuples
和dicts
进行比较,并且使用相应元素的比较运算符进行比较。
成员资格测试运算¶
运算符 in
和 not in
测试成员资格。 如果 x
是 s
的成员,则 x in s
的求值结果为 True
,否则为 False
。 x not in s
等效于 not x in s
。 此运算符支持 lists
、dicts
和 tuples
,如果用户定义的类型实现了 __contains__
方法,则可以与用户定义的类型一起使用。
身份比较¶
对于除 int
、double
、bool
和 torch.device
之外的所有类型,运算符 is
和 is not
测试对象的身份; 当且仅当 x
和 y
是同一个对象时,x is y
才为 True
。 对于所有其他类型,is
等效于使用 ==
比较它们。 x is not y
产生 x is y
的反值。
布尔运算¶
or_test ::= and_test | or_test 'or' and_test
and_test ::= not_test | and_test 'and' not_test
not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test
用户定义的对象可以通过实现 __bool__
方法来自定义其到 bool
的转换。 运算符 not
在其操作数为假时产生 True
,否则产生 False
。 表达式 x and y
首先求值 x
; 如果为 False
,则返回其值 (False
); 否则,求值 y
并返回其值 (False
或 True
)。 表达式 x or y
首先求值 x
; 如果为 True
,则返回其值 (True
); 否则,求值 y
并返回其值 (False
或 True
)。
条件表达式¶
conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression]
expression ::= conditional_expression
表达式 x if c else y
首先求值条件 c
而不是 x。 如果 c
为 True
,则求值 x
并返回其值; 否则,求值 y
并返回其值。 与 if 语句一样,x
和 y
必须求值为相同类型的值。
表达式列表¶
expression_list ::= expression (',' expression)* [',']
starred_item ::= '*' primary
带星号的项目只能出现在赋值语句的左侧,例如,a, *b, c = ...
。
简单语句¶
以下部分描述了 TorchScript 中支持的简单语句的语法。 它以 Python 语言参考的简单语句章节为蓝本。
表达式语句¶
expression_stmt ::= starred_expression
starred_expression ::= expression | (starred_item ",")* [starred_item]
starred_item ::= assignment_expression | "*" or_expr
赋值语句¶
assignment_stmt ::= (target_list "=")+ (starred_expression)
target_list ::= target ("," target)* [","]
target ::= identifier
| "(" [target_list] ")"
| "[" [target_list] "]"
| attributeref
| subscription
| slicing
| "*" target
增强赋值语句¶
augmented_assignment_stmt ::= augtarget augop (expression_list)
augtarget ::= identifier | attributeref | subscription
augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" |
"**="| ">>=" | "<<=" | "&=" | "^=" | "|="
注解赋值语句¶
annotated_assignment_stmt ::= augtarget ":" expression
["=" (starred_expression)]
raise
语句¶
raise_stmt ::= "raise" [expression ["from" expression]]
TorchScript 中的 raise 语句不支持 try\except\finally
。
assert
语句¶
assert_stmt ::= "assert" expression ["," expression]
TorchScript 中的 assert 语句不支持 try\except\finally
。
return
语句¶
return_stmt ::= "return" [expression_list]
TorchScript 中的 return 语句不支持 try\except\finally
。
del
语句¶
del_stmt ::= "del" target_list
pass
语句¶
pass_stmt ::= "pass"
print
语句¶
print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")"
break
语句¶
break_stmt ::= "break"
continue
语句:¶
continue_stmt ::= "continue"
复合语句¶
以下部分描述了 TorchScript 中支持的复合语句的语法。 本节还重点介绍了 Torchscript 与常规 Python 语句的不同之处。 它以 Python 语言参考的复合语句章节为蓝本。
if
语句¶
Torchscript 支持基本 if/else
和三元 if/else
。
基本 if/else
语句¶
if_stmt ::= "if" assignment_expression ":" suite
("elif" assignment_expression ":" suite)
["else" ":" suite]
elif
语句可以重复任意次数,但必须在 else
语句之前。
三元 if/else
语句¶
if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list]
示例 1
具有 1 个维度的 tensor
会提升为 bool
import torch
@torch.jit.script
def fn(x: torch.Tensor):
if x: # The tensor gets promoted to bool
return True
return False
print(fn(torch.rand(1)))
上面的示例产生以下输出
True
示例 2
一个带有多个维度的 tensor
不会被提升为 bool
类型。
import torch
# Multi dimensional Tensors error out.
@torch.jit.script
def fn():
if torch.rand(2):
print("Tensor is available")
if torch.rand(4,5,6):
print("Tensor is available")
print(fn())
运行上述代码会产生以下 RuntimeError
。
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
@torch.jit.script
def fn():
if torch.rand(2):
~~~~~~~~~~~~ <--- HERE
print("Tensor is available")
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
如果条件变量被注解为 final
,则会根据条件变量的求值结果评估 true 分支或 false 分支。
示例 3
在此示例中,仅评估 True 分支,因为 a
被注解为 final
并设置为 True
。
import torch
a : torch.jit.final[Bool] = True
if a:
return torch.empty(2,3)
else:
return []
while
语句¶
while_stmt ::= "while" assignment_expression ":" suite
Torchscript 中不支持 while…else 语句。它会导致 RuntimeError
。
for-in
语句¶
for_stmt ::= "for" target_list "in" expression_list ":" suite
["else" ":" suite]
Torchscript 中不支持 for...else
语句。它会导致 RuntimeError
。
示例 1
元组上的 For 循环:这些循环会展开,为元组的每个成员生成一个主体。主体必须为每个成员正确进行类型检查。
import torch
from typing import Tuple
@torch.jit.script
def fn():
tup = (3, torch.ones(4))
for x in tup:
print(x)
fn()
上面的示例产生以下输出
3
1
1
1
1
[ CPUFloatType{4} ]
示例 2
列表上的 For 循环:对 nn.ModuleList
的 for 循环将在编译时展开循环主体,每个成员对应模块列表中的一个成员。
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2))
def forward(self, input):
return self.weight + input
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])
def forward(self, v):
for module in self.mods:
v = module(v)
return v
model = torch.jit.script(MyModule())
with
语句¶
with
语句用于使用上下文管理器定义的方法包装代码块的执行。
with_stmt ::= "with" with_item ("," with_item) ":" suite
with_item ::= expression ["as" target]
如果在
with
语句中包含目标,则上下文管理器的__enter__()
的返回值将分配给它。与 python 不同,如果异常导致套件退出,则其类型、值和回溯不会作为参数传递给__exit__()
。将提供三个None
参数。try
、except
和finally
语句在with
代码块内部不受支持。无法抑制在
with
代码块内引发的异常。
tuple
语句¶
tuple_stmt ::= tuple([iterables])
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。您无法使用此内置函数将 List 转换为 Tuple。
将所有输出解包到元组中已包含在内
abc = func() # Function that returns a tuple
a,b = func()
getattr
语句¶
getattr_stmt ::= getattr(object, name[, default])
属性名称必须是文字字符串。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
hasattr
语句¶
hasattr_stmt ::= hasattr(object, name)
属性名称必须是文字字符串。
不支持模块类型对象(例如,torch._C)。
不支持自定义类对象(例如,torch.classes.*)。
zip
语句¶
zip_stmt ::= zip(iterable1, iterable2)
参数必须是可迭代对象。
支持具有相同外部容器类型但长度不同的两个可迭代对象。
示例 1
两个可迭代对象必须是相同的容器类型
a = [1, 2] # List
b = [2, 3, 4] # List
zip(a, b) # works
示例 2
此示例失败,因为可迭代对象是不同的容器类型
a = (1, 2) # Tuple
b = [2, 3, 4] # List
zip(a, b) # Runtime error
运行上述代码会产生以下 RuntimeError
。
RuntimeError: Can not iterate over a module list or
tuple with a value that does not have a statically determinable length.
示例 3
支持具有相同容器类型但不同数据类型的两个可迭代对象
a = [1.3, 2.4]
b = [2, 3, 4]
zip(a, b) # Works
TorchScript 中的可迭代类型包括 Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和 torch.nn.ModuleDict
。
enumerate
语句¶
enumerate_stmt ::= enumerate([iterable])
参数必须是可迭代对象。
TorchScript 中的可迭代类型包括
Tensors
、lists
、tuples
、dictionaries
、strings
、torch.nn.ModuleList
和torch.nn.ModuleDict
。
Python 值¶
解析规则¶
当给定 Python 值时,TorchScript 会尝试通过以下五种不同的方式解析它
- 可编译的 Python 实现
当 Python 值由 TorchScript 可以编译的 Python 实现支持时,TorchScript 会编译并使用底层 Python 实现。
示例:
torch.jit.Attribute
- Op Python 包装器
当 Python 值是原生 PyTorch op 的包装器时,TorchScript 会发出相应的运算符。
示例:
torch.jit._logging.add_stat_value
- Python 对象标识匹配
对于 TorchScript 支持的有限的
torch.*
API 调用(以 Python 值的形式),TorchScript 尝试将 Python 值与集合中的每个项目进行匹配。匹配时,TorchScript 会生成相应的
SugaredValue
实例,其中包含这些值的降低逻辑。示例:
torch.jit.isinstance()
- 名称匹配
对于 Python 内置函数和常量,TorchScript 通过名称识别它们,并创建一个相应的
SugaredValue
实例来实现它们的功能。示例:
all()
- 值快照
对于来自无法识别模块的 Python 值,TorchScript 会尝试获取值的快照,并将其转换为正在编译的函数或方法的图中的常量。
示例:
math.pi
Python 内置函数支持¶
内置函数 |
支持级别 |
注释 |
---|---|---|
|
部分支持 |
仅支持 |
|
完整支持 |
|
|
完整支持 |
|
|
不支持 |
|
|
部分支持 |
仅支持 |
|
部分支持 |
仅支持 |
|
不支持 |
|
|
不支持 |
|
|
不支持 |
|
|
不支持 |
|
|
部分支持 |
仅支持 ASCII 字符集。 |
|
完整支持 |
|
|
不支持 |
|
|
不支持 |
|
|
不支持 |
|
|
完整支持 |
|
|
不支持 |
|
|
完整支持 |
|
|
完整支持 |
|
|
不支持 |
|
|
不支持 |
|
|
不支持 |
|
|
部分支持 |
不遵循 |
|
部分支持 |
不支持手动索引规范。| 不支持格式类型修饰符。 |
|
不支持 |
|
|
部分支持 |
属性名称必须是字符串文字。 |
|
不支持 |
|
|
部分支持 |
属性名称必须是字符串文字。 |
|
完整支持 |
|
|
部分支持 |
仅支持 |
|
完整支持 |
仅支持 |
|
不支持 |
|
|
部分支持 |
不支持 |
|
完整支持 |
当检查诸如 |
|
不支持 |
|
|
不支持 |
|
|
完整支持 |
|
|
完整支持 |
|
|
部分支持 |
仅支持 ASCII 字符集。 |
|
完整支持 |
|
|
部分支持 |
不支持 |
|
不支持 |
|
|
完整支持 |
|
|
不支持 |
|
|
不支持 |
|
|
部分支持 |
不支持 |
|
不支持 |
|
|
不支持 |
|
|
完整支持 |
|
|
部分支持 |
不支持 |
|
完整支持 |
|
|
部分支持 |
不支持 |
|
完整支持 |
|
|
部分支持 |
它只能在 |
|
不支持 |
|
|
不支持 |
|
|
完整支持 |
|
|
不支持 |
torch.* API¶
远程过程调用¶
TorchScript 支持 RPC API 的子集,该子集支持在指定的远程工作程序而不是本地运行函数。
具体来说,完全支持以下 API
torch.distributed.rpc.rpc_sync()
rpc_sync()
进行阻塞 RPC 调用,以在远程工作程序上运行函数。RPC 消息与 Python 代码的执行并行发送和接收。有关其用法和示例的更多详细信息,请参见
rpc_sync()
。
torch.distributed.rpc.rpc_async()
rpc_async()
进行非阻塞 RPC 调用,以在远程工作程序上运行函数。RPC 消息与 Python 代码的执行并行发送和接收。有关其用法和示例的更多详细信息,请参见
rpc_async()
。
torch.distributed.rpc.remote()
remote.()
在工作程序上执行远程调用,并获取远程引用RRef
作为返回值。有关其用法和示例的更多详细信息,请参见
remote()
。
异步执行¶
TorchScript 使您能够创建异步计算任务,以更好地利用计算资源。这是通过支持一系列仅在 TorchScript 中可用的 API 来完成的
类型注解¶
TorchScript 是静态类型的。它提供并支持一组实用程序来帮助注解变量和属性
torch.jit.annotate()
在 Python 3 样式类型提示效果不佳的地方为 TorchScript 提供类型提示。
一个常见的示例是为诸如
[]
之类的表达式注解类型。[]
默认被视为List[torch.Tensor]
。当需要不同的类型时,您可以使用此代码来提示 TorchScript:torch.jit.annotate(List[int], [])
。更多详细信息请参见
annotate()
torch.jit.Attribute
常见用例包括为
torch.nn.Module
属性提供类型提示。由于 TorchScript 不解析它们的__init__
方法,因此应在模块的__init__
方法中使用torch.jit.Attribute
而不是torch.jit.annotate
。更多详细信息请参见
Attribute()
torch.jit.Final
Python 的
typing.Final
的别名。torch.jit.Final
仅出于向后兼容性原因而保留。
元编程¶
TorchScript 提供了一组实用程序来促进元编程
torch.jit.is_scripting()
返回一个布尔值,指示当前程序是否由
torch.jit.script
编译。当在
assert
或if
语句中使用时,torch.jit.is_scripting()
的计算结果为False
的作用域或分支不会被编译。它的值可以在编译时静态评估,因此通常在
if
语句中使用,以阻止 TorchScript 编译其中一个分支。更多详细信息和示例请参见
is_scripting()
torch.jit.is_tracing()
返回一个布尔值,指示当前程序是否由
torch.jit.trace
/torch.jit.trace_module
跟踪。更多详细信息请参见
is_tracing()
@torch.jit.ignore
此装饰器向编译器指示应忽略函数或方法,并将其保留为 Python 函数。
这允许您在模型中保留尚未与 TorchScript 兼容的代码。
如果从 TorchScript 调用由
@torch.jit.ignore
修饰的函数,则忽略的函数会将调用分派到 Python 解释器。具有忽略函数的模型无法导出。
更多详细信息和示例请参见
ignore()
@torch.jit.unused
此装饰器向编译器指示应忽略函数或方法,并将其替换为引发异常。
这允许您在模型中保留尚未与 TorchScript 兼容的代码,并且仍然可以导出模型。
如果从 TorchScript 调用由
@torch.jit.unused
修饰的函数,则会引发运行时错误。更多详细信息和示例请参见
unused()
类型细化¶
torch.jit.isinstance()
返回一个布尔值,指示变量是否为指定的类型。
有关其用法和示例的更多详细信息,请参见
isinstance()
。