快捷方式

TorchScript 语言参考

TorchScript 是 Python 的静态类型子集,可以直接编写(使用 @torch.jit.script 装饰器)或通过跟踪从 Python 代码自动生成。使用跟踪时,代码会自动转换为 Python 的这个子集,仅记录 tensor 上的实际算子,而简单地执行并丢弃周围的其他 Python 代码。

使用 @torch.jit.script 装饰器直接编写 TorchScript 时,程序员必须只使用 TorchScript 支持的 Python 子集。本节文档介绍了 TorchScript 中支持的内容,就像它是独立语言的参考一样。本参考中未提及的任何 Python 特性都不是 TorchScript 的一部分。有关可用 PyTorch tensor 方法、模块和函数的完整参考,请参见 内置函数

作为 Python 的子集,任何有效的 TorchScript 函数也是有效的 Python 函数。这使得禁用 TorchScript 并使用标准的 Python 工具(例如 pdb)调试函数成为可能。反之则不然:许多有效的 Python 程序并不是有效的 TorchScript 程序。相反,TorchScript 专门关注表示 PyTorch 中神经网络模型所需的 Python 特性。

类型

TorchScript 与完整 Python 语言最大的区别在于,TorchScript 只支持表达神经网络模型所需的一小部分类型。具体来说,TorchScript 支持以下类型:

类型

描述

Tensor

任何 dtype、维度或后端的 PyTorch tensor

Tuple[T0, T1, ..., TN]

包含子类型 T0T1 等的元组(例如 Tuple[Tensor, Tensor]

bool

布尔值

int

标量整数

float

标量浮点数

str

字符串

List[T]

其中所有成员都是类型 T 的列表

Optional[T]

一个值,可以是 None 或类型 T

Dict[K, V]

键类型为 K、值类型为 V 的字典。只允许将 strintfloat 作为键类型。

T

一个 TorchScript 类

E

一个 TorchScript 枚举

NamedTuple[T0, T1, ...]

一个 collections.namedtuple 元组类型

Union[T0, T1, ...]

子类型 T0T1 等之一

与 Python 不同,TorchScript 函数中的每个变量都必须具有单一的静态类型。这使得优化 TorchScript 函数更加容易。

示例(类型不匹配)

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Traceback (most recent call last):
  ...
RuntimeError: ...

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
    else:
        r = 4
    return r
           ~ <--- HERE...

不支持的类型构造

TorchScript 不支持 typing 模块的所有特性和类型。其中一些是更基础的东西,未来不太可能添加,而另一些则可能会在用户需求足够多且成为优先事项时添加。

TorchScript 中不支持 typing 模块的以下类型和特性。

项目

描述

typing.Any

typing.Any 目前正在开发中,尚未发布

typing.NoReturn

未实现

typing.Sequence

未实现

typing.Callable

未实现

typing.Literal

未实现

typing.ClassVar

未实现

typing.Final

对于模块属性的类属性注解支持此项,但对于函数不支持

typing.AnyStr

TorchScript 不支持 bytes,因此不使用此类型

typing.overload

typing.overload 目前正在开发中,尚未发布

类型别名

未实现

名义子类型 vs 结构子类型

名义类型正在开发中,但结构类型尚未开发

NewType

不太可能实现

泛型

不太可能实现

本文档中未明确列出的 typing 模块中的任何其他功能均不受支持。

默认类型

默认情况下,TorchScript 函数的所有参数都假定为 Tensor。要指定 TorchScript 函数的参数是另一种类型,可以使用上面列出的 MyPy 风格的类型注解。

import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

注意

也可以使用 typing 模块中的 Python 3 类型提示来注解类型。

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

空列表假定为 List[Tensor],空字典假定为 Dict[str, Tensor]。要实例化其他类型的空列表或字典,请使用Python 3 类型提示

示例(Python 3 的类型注解)

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

Optional 类型细化

当在 if 语句的条件内进行与 None 的比较或在 assert 中检查时,TorchScript 将细化类型为 Optional[T] 的变量的类型。编译器可以处理与 andornot 组合的多个 None 检查。对于未明确编写的 if 语句的 else 块,也会发生细化。

None 检查必须在 if 语句的条件内;将 None 检查赋值给变量并在 if 语句的条件中使用它不会细化检查中变量的类型。只有局部变量会被细化,像 self.x 这样的属性不会,必须赋值给局部变量才能被细化。

示例(细化参数和局部变量的类型)

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super().__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

TorchScript 类

警告

TorchScript 类支持是实验性的。目前它最适合简单的记录类类型(可以将其视为带有方法的 NamedTuple)。

如果使用 @torch.jit.script 进行注解,Python 类可以在 TorchScript 中使用,这类似于声明 TorchScript 函数的方式

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

此子集受到限制

  • 所有函数都必须是有效的 TorchScript 函数(包括 __init__())。

  • 类必须是 new-style 类,因为我们使用 __new__() 和 pybind11 来构造它们。

  • TorchScript 类是静态类型的。成员只能通过在 __init__() 方法中赋值给 self 来声明。

    例如,在 __init__() 方法之外赋值给 self

    @torch.jit.script
    class Foo:
      def assign_x(self):
        self.x = torch.rand(2, 3)
    

    将导致

    RuntimeError:
    Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    def assign_x(self):
      self.x = torch.rand(2, 3)
      ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
  • 类主体中除了方法定义之外不允许有任何表达式。

  • 不支持继承或任何其他多态策略,除了继承 object 以指定 new-style 类。

定义类后,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 类型一样互换使用

# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

TorchScript 枚举

Python 枚举可以在 TorchScript 中使用,无需任何额外注解或代码

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定义枚举后,它可以在 TorchScript 和 Python 中像任何其他 TorchScript 类型一样互换使用。枚举值的类型必须是 intfloatstr。所有值必须具有相同的类型;不支持枚举值的异构类型。

具名元组

collections.namedtuple 产生的类型可以在 TorchScript 中使用。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

可迭代对象

某些函数(例如 zipenumerate)只能作用于可迭代类型。TorchScript 中的可迭代类型包括 Tensor、列表、元组、字典、字符串、torch.nn.ModuleListtorch.nn.ModuleDict

表达式

支持以下 Python 表达式。

字面量

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

列表构建

空列表假定类型为 List[Tensor]。其他列表字面量的类型派生自成员的类型。有关更多详细信息,请参见默认类型

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

元组构建

(3, 4)
(3,)

字典构建

空字典假定类型为 Dict[str, Tensor]。其他字典字面量的类型派生自成员的类型。有关更多详细信息,请参见默认类型

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

变量

有关变量如何解析,请参见变量解析

my_variable_name

算术运算符

a + b
a - b
a * b
a / b
a ^ b
a @ b

比较运算符

a == b
a != b
a < b
a > b
a <= b
a >= b

逻辑运算符

a and b
a or b
not b

下标和切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

函数调用

调用内置函数

torch.rand(3, dtype=torch.int)

调用其他 script 函数

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

方法调用

调用内置类型(如 tensor)的方法:x.mm(y)

对于模块,必须先编译方法才能调用它们。TorchScript 编译器在编译其他方法时会递归编译它看到的方法。默认情况下,编译从 forward 方法开始。 forward 调用的任何方法都会被编译,这些方法调用的任何方法也会被编译,依此类推。要从 forward 以外的方法开始编译,请使用 @torch.jit.export 装饰器(forward 隐式标记为 @torch.jit.export)。

直接调用子模块(例如 self.resnet(input))等同于调用其 forward 方法(例如 self.resnet.forward(input))。

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

三元表达式

x if x > y else y

类型转换

float(ten)
int(3.5)
bool(ten)
str(2)``

访问模块参数

self.my_parameter
self.my_submodule.my_parameter

语句

TorchScript 支持以下类型的语句

简单赋值

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

模式匹配赋值

a, b = tuple_or_list
a, b, *c = a_tuple

多重赋值

a = b, c = tup

If 语句

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除了布尔值,浮点数、整数和 Tensor 也可以在条件语句中使用,并会被隐式转换为布尔值。

While 循环

a = 0
while a < 4:
    print(a)
    a += 1

带 range 的 For 循环

x = 0
for i in range(10):
    x *= i

遍历元组的 For 循环

这些会展开循环,为元组的每个成员生成一个主体。主体必须对每个成员进行正确的类型检查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

遍历常量 nn.ModuleList 的 For 循环

要在已编译的方法中使用 nn.ModuleList,必须通过将属性名称添加到类型的 __constants__ 列表中来将其标记为常量。遍历 nn.ModuleList 的 For 循环会在编译时展开循环主体,并处理常量模块列表的每个成员。

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

Break 和 Continue

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

Return

return a, b

变量解析

TorchScript 支持 Python 变量解析(即作用域)规则的一个子集。局部变量的行为与 Python 中相同,但有一个限制:变量在通过函数的所有路径上必须具有相同的类型。如果变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束之后使用它是错误的。

同样,如果在函数中只沿某些路径定义了变量,则不允许使用该变量。

示例

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...

非局部变量在定义函数时在编译时解析为 Python 值。然后使用Python 值的使用中描述的规则将这些值转换为 TorchScript 值。

Python 值的使用

为了让编写 TorchScript 更方便,我们允许脚本代码引用周围作用域中的 Python 值。例如,每当引用 torch 时,TorchScript 编译器实际上是在函数声明时将其解析为 torch Python 模块。这些 Python 值不是 TorchScript 的一等公民。相反,它们在编译时被“去糖化”为 TorchScript 支持的基本类型。这取决于编译时引用的 Python 值的动态类型。本节描述了在 TorchScript 中访问 Python 值时使用的规则。

函数

TorchScript 可以调用 Python 函数。在将模型逐步转换为 TorchScript 时,此功能非常有用。可以逐个函数地将模型移动到 TorchScript,同时保留对 Python 函数的调用。这样您就可以在转换过程中逐步检查模型的正确性。

torch.jit.is_scripting()[source][source]

在编译时返回 True,否则返回 False 的函数。这在使用 @unused 装饰器时特别有用,可以将模型中尚未与 TorchScript 兼容的代码保留下来。.. testcode

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
    if torch.jit.is_scripting():
        return torch.linear(x)
    else:
        return unsupported_linear_op(x)
返回类型

bool

torch.jit.is_tracing()[source][source]

返回布尔值。

在跟踪时(如果在使用 torch.jit.trace 跟踪代码期间调用了某个函数)返回 True,否则返回 False

在 Python 模块上查找属性

TorchScript 可以查找模块上的属性。通过这种方式访问内置函数,例如 torch.add。这使得 TorchScript 可以调用其他模块中定义的函数。

Python 定义的常量

TorchScript 还提供了使用 Python 中定义的常量的方法。这些常量可以用于将超参数硬编码到函数中,或定义通用常量。有两种方法可以指定将 Python 值视为常量。

  1. 作为模块属性查找的值假定为常量

import math
import torch

@torch.jit.script
def fn():
    return math.pi
  1. ScriptModule 的属性可以通过使用 Final[T] 注解来标记为常量

import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super().__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支持的常量 Python 类型有

  • int

  • float

  • bool

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支持类型的元组

  • 可在 TorchScript for 循环中使用的 torch.nn.ModuleList

模块属性

torch.nn.Parameter 包装器和 register_buffer 可用于将 tensor 赋值给模块。如果其他赋值给已编译模块的值的类型可以推断,则这些值将被添加到已编译的模块中。TorchScript 中所有可用的类型都可以用作模块属性。Tensor 属性在语义上与 buffer 相同。空列表和字典以及 None 值的类型无法推断,必须通过 PEP 526 风格的类注解来指定。如果无法推断类型且未明确注解,则不会将其作为属性添加到生成的 ScriptModule 中。

示例

from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super().__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源