快捷方式

torch.export IR 规范

导出 IR 是一种用于编译器的中间表示 (IR),它与 MLIR 和 TorchScript 类似。它专门设计用于表达 PyTorch 程序的语义。导出 IR 主要以简化的操作列表来表示计算,对动态性(如控制流)的支持有限。

要创建导出 IR 图,可以使用前端,通过跟踪专业化机制可靠地捕获 PyTorch 程序。然后,导出的 IR 可以由后端进行优化和执行。这可以通过 torch.export.export() 立即完成。

本文档将涵盖的关键概念包括

  • ExportedProgram:包含导出 IR 程序的数据结构

  • 图:由节点列表组成。

  • 节点:表示操作、控制流和存储在此节点上的元数据。

  • 值由节点产生和消耗。

  • 类型与值和节点相关联。

  • 还定义了值的大小和内存布局。

假设

本文档假设读者充分熟悉 PyTorch,特别是 torch.fx 及其相关工具。因此,它将停止描述 torch.fx 文档和论文中已有的内容。

什么是导出 IR

导出 IR 是 PyTorch 程序的基于图的中间表示 IR。导出 IR 基于 torch.fx.Graph 实现。换句话说,所有导出 IR 图也是有效的 FX 图,如果使用标准 FX 语义解释,则导出 IR 可以被可靠地解释。一个含义是,导出的图可以通过标准 FX 代码生成转换为有效的 Python 程序。

本文档将主要关注突出显示导出 IR 在严格性方面与 FX 不同的区域,同时跳过与 FX 相似的部分。

ExportedProgram

顶层导出 IR 构造是 torch.export.ExportedProgram 类。它将 PyTorch 模型的计算图(通常是 torch.nn.Module)与该模型使用的参数或权重捆绑在一起。

torch.export.ExportedProgram 类的一些值得注意的属性是

  • graph_module (torch.fx.GraphModule):包含 PyTorch 模型扁平化计算图的数据结构。该图可以直接通过 ExportedProgram.graph 访问。

  • graph_signature (torch.export.ExportGraphSignature):图签名,它指定图中使用的和变异的参数和缓冲区名称。参数和缓冲区不是作为图的属性存储,而是被提升为图的输入。graph_signature 用于跟踪关于这些参数和缓冲区的附加信息。

  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含参数和缓冲区的数据结构。

  • range_constraints (Dict[sympy.Symbol, RangeConstraint]):对于使用数据相关行为导出的程序,每个节点上的元数据将包含符号形状(看起来像 s0, i0)。此属性将符号形状映射到其下限/上限范围。

导出 IR 图是以 DAG(有向无环图)形式表示的 PyTorch 程序。此图中的每个节点表示特定的计算或操作,并且此图的边由节点之间的引用组成。

我们可以将图视为具有以下模式

class Graph:
  nodes: List[Node]

实际上,导出 IR 的图被实现为 torch.fx.Graph Python 类。

导出 IR 图包含以下节点(节点将在下一节中更详细地描述)

  • 0 个或多个 op 类型为 placeholder 的节点

  • 0 个或多个 op 类型为 call_function 的节点

  • 恰好 1 个 op 类型为 output 的节点

推论: 最小的有效图将由一个节点组成。即节点永远不为空。

定义: 图的 placeholder 节点集表示 GraphModule 的图的输入output 节点表示 GraphModule 的图的输出

示例

import torch
from torch import nn

class MyModule(nn.Module):

    def forward(self, x, y):
      return x + y

example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
  %x : [num_users=1] = placeholder[target=x]
  %y : [num_users=1] = placeholder[target=y]
  %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
  return (add,)

以上是图的文本表示,每行是一个节点。

节点

节点表示特定的计算或操作,并在 Python 中使用 torch.fx.Node 类表示。节点之间的边通过 Node 类的 args 属性表示为对其他节点的直接引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,例如算子调用、占位符(又名输入)、条件和循环。

节点具有以下模式

class Node:
  name: str # name of node
  op_name: str  # type of operation

  # interpretation of the fields below depends on op_name
  target: [str|Callable]
  args: List[object]
  kwargs: Dict[str, object]
  meta: Dict[str, object]

FX 文本格式

如上面的示例所示,请注意,每行都具有以下格式

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

此格式捕获 Node 类中存在的所有内容,除了 meta 之外,以紧凑的格式。

具体而言

  • <name> 是节点名称,因为它将出现在 node.name 中。

  • <op_name>node.op 字段,它必须是以下之一:<call_function><placeholder><get_attr><output>

  • <target> 是节点的 target,如 node.target。此字段的含义取决于 op_name

  • args1, … args 4…node.args 元组中列出的内容。如果列表中的值是 torch.fx.Node,则将用前导 % 特别指示。

例如,调用 add 算子将显示为

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x, %y 是另外两个名称分别为 x 和 y 的节点。值得注意的是,字符串 torch.op.aten.add.Tensor 表示实际存储在 target 字段中的可调用对象,而不仅仅是其字符串名称。

此文本格式的最后一行是

return [add]

这是一个 op_name = output 的节点,表示我们正在返回这一个元素。

call_function

call_function 节点表示对算子的调用。

定义

  • 函数式: 如果一个可调用对象满足以下所有要求,我们称其为“函数式”

    • 非变异:算子不会变异其输入的值(对于张量,这包括元数据和数据)。

    • 无副作用:算子不会变异从外部可见的状态,例如更改模块参数的值。

  • 算子: 是具有预定义模式的函数式可调用对象。此类算子的示例包括函数式 ATen 算子。

在 FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与原始 FX call_function 的区别

  1. 在 FX 图中,call_function 可以引用任何可调用对象,在导出 IR 中,我们将其限制为仅 ATen 算子、自定义算子和控制流算子的选定子集。

  2. 在导出 IR 中,常量参数将嵌入在图中。

  3. 在 FX 图中,get_attr 节点可以表示读取存储在图模块中的任何属性。但是,在导出 IR 中,这被限制为仅读取子模块,因为所有参数/缓冲区都将作为输入传递到图模块。

元数据

Node.meta 是附加到每个 FX 节点的 dict。但是,FX 规范未指定可能或将会有哪些元数据。导出 IR 提供了更强的约定,特别是所有 call_function 节点将保证具有且仅具有以下元数据字段

  • node.meta["stack_trace"] 是一个字符串,其中包含引用原始 Python 源代码的 Python 堆栈跟踪。堆栈跟踪示例类似于

    File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
    return y + 1
    
  • node.meta["val"] 描述了运行操作的输出。它可以是 <symint><FakeTensor>List[Union[FakeTensor, SymInt]]None 类型。

  • node.meta["nn_module_stack"] 描述了 torch.nn.Module 的“堆栈跟踪”,节点来自该模块,如果它来自 torch.nn.Module 调用。例如,如果一个节点包含从 torch.nn.Sequential 模块内部的 torch.nn.Linear 模块调用的 addmm op,则 nn_module_stack 将类似于

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
    
  • node.meta["source_fn_stack"] 包含 torch 函数或叶子 torch.nn.Module 类,该节点在分解之前从中调用。例如,来自 torch.nn.Linear 模块调用的包含 addmm op 的节点将在其 source_fn 中包含 torch.nn.Linear,并且来自 torch.nn.functional.Linear 模块调用的包含 addmm op 的节点将在其 source_fn 中包含 torch.nn.functional.Linear

placeholder

Placeholder 表示图的输入。它的语义与 FX 中的完全相同。Placeholder 节点必须是图中节点列表中的前 N 个节点。N 可以为零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

target 字段是一个字符串,它是输入的名称。

args(如果非空)的大小应为 1,表示此输入的默认值。

元数据

Placeholder 节点也具有 meta[‘val’],如 call_function 节点一样。在这种情况下,val 字段表示图预期为此输入参数接收的输入形状/dtype。

output

输出调用表示函数中的返回语句;因此它终止当前图。有且仅有一个输出节点,它将始终是图的最后一个节点。

在 FX 中的表示

output[](args = (%something, …))

这与 torch.fx 中的语义完全相同。args 表示要返回的节点。

元数据

输出节点具有与 call_function 节点相同的元数据。

get_attr

get_attr 节点表示从封装的 torch.fx.GraphModule 读取子模块。与来自 torch.fx.symbolic_trace() 的原始 FX 图不同,在原始 FX 图中,get_attr 节点用于从顶层 torch.fx.GraphModule 读取参数和缓冲区等属性,参数和缓冲区作为输入传递到图模块,并存储在顶层 torch.export.ExportedProgram 中。

在 FX 中的表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 读取包含 sin 算子的子模块 true_graph_0

参考

SymInt

SymInt 是一个对象,它可以是字面整数或表示整数的符号(在 Python 中由 sympy.Symbol 类表示)。当 SymInt 是符号时,它描述一个整数类型的变量,该变量在编译时对图未知,也就是说,它的值仅在运行时才知道。

FakeTensor

FakeTensor 是一个包含张量元数据的对象。它可以被视为具有以下元数据。

class FakeTensor:
  size: List[SymInt]
  dtype: torch.dtype
  device: torch.device
  dim_order: List[int]  # This doesn't exist yet

FakeTensor 的 size 字段是整数或 SymInts 的列表。如果存在 SymInts,则表示此张量具有动态形状。如果存在整数,则假定张量将具有该确切的静态形状。TensorMeta 的秩永远不是动态的。dtype 字段表示该节点输出的 dtype。Edge IR 中没有隐式类型提升。FakeTensor 中没有步幅。

换句话说

  • 如果 node.target 中的算子返回张量,则 node.meta['val'] 是描述该张量的 FakeTensor。

  • 如果 node.target 中的算子返回张量的 n 元组,则 node.meta['val'] 是描述每个张量的 FakeTensor 的 n 元组。

  • 如果 node.target 中的算子返回编译时已知的 int/float/标量,则 node.meta['val'] 为 None。

  • 如果 node.target 中的算子返回编译时未知的 int/float/标量,则 node.meta['val'] 的类型为 SymInt。

例如

  • aten::add 返回张量;因此其规范将是 FakeTensor,其中包含此算子返回的张量的 dtype 和 size。

  • aten::sym_size 返回一个整数;因此其 val 将是 SymInt,因为其值仅在运行时可用。

  • max_pool2d_with_indexes 返回 (Tensor, Tensor) 的元组;因此规范也将是 FakeTensor 对象的 2 元组,第一个 TensorMeta 描述返回值等的第一个元素。

Python 代码

def add_one(x):
  return torch.ops.aten(x, 1)

graph():
  %ph_0 : [#users=1] = placeholder[target=ph_0]
  %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
  return [add_tensor]

FakeTensor

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able 类型

我们将类型定义为“Pytree-able”,如果它是叶子类型或包含其他 Pytree-able 类型的容器类型。

注意

pytree 的概念与 此处为 JAX 文档记录的概念相同

以下类型定义为 叶子类型

类型

定义

张量

torch.Tensor

标量

来自 Python 的任何数值类型,包括整数类型、浮点类型和零维张量。

int

Python int(在 C++ 中绑定为 int64_t)

float

Python float(在 C++ 中绑定为 double)

bool

Python bool

str

Python 字符串

ScalarType

torch.dtype

Layout

torch.layout

MemoryFormat

torch.memory_format

Device

torch.device

以下类型定义为 容器类型

类型

定义

元组

Python 元组

列表

Python 列表

字典

具有标量键的 Python 字典

NamedTuple

Python namedtuple

Dataclass

必须通过 register_dataclass 注册

自定义类

使用 _register_pytree_node 定义的任何自定义类

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源