torch.export IR 规范¶
Export IR 是用于编译器的中间表示 (IR),与 MLIR 和 TorchScript 类似。它专门设计用于表达 PyTorch 程序的语义。Export IR 主要通过精简的运算列表表示计算,对控制流等动态特性的支持有限。
为了创建 Export IR 图,可以使用前端通过跟踪特化机制可靠地捕获 PyTorch 程序。生成的 Export IR 可以由后端进行优化和执行。目前可以通过 torch.export.export()
实现这一点。
本文档将涵盖的关键概念包括
ExportedProgram:包含 Export IR 程序的数据结构
Graph:由节点列表组成。
Nodes:表示运算、控制流以及存储在此节点上的元数据。
值由节点产生和消费。
类型与值和节点相关联。
还定义了值的大小和内存布局。
什么是 Export IR¶
Export IR 是 PyTorch 程序的基于图的中间表示 IR。Export IR 是在 torch.fx.Graph
之上实现的。换句话说,所有 Export IR 图都是有效的 FX 图,如果使用标准的 FX 语义解释,Export IR 可以被可靠地解释。一个推论是,导出的图可以通过标准的 FX 代码生成转换为有效的 Python 程序。
本文档将主要重点介绍 Export IR 在严格性方面与 FX 的不同之处,同时跳过与 FX 相似的部分。
ExportedProgram¶
顶层 Export IR 构造是 torch.export.ExportedProgram
类。它将 PyTorch 模型的计算图(通常是 torch.nn.Module
)与该模型使用的参数或权重捆绑在一起。
torch.export.ExportedProgram
类的一些值得注意的属性包括
graph_module
(torch.fx.GraphModule
):包含 PyTorch 模型展平计算图的数据结构。可以通过 ExportedProgram.graph 直接访问该图。graph_signature
(torch.export.ExportGraphSignature
):图签名,指定图中使用和修改的参数和缓冲区名称。参数和缓冲区不作为图的属性存储,而是作为图的输入被提升。graph_signature 用于跟踪这些参数和缓冲区的附加信息。state_dict
(Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
):包含参数和缓冲区的数据结构。range_constraints
(Dict[sympy.Symbol, RangeConstraint]
):对于导出时具有数据依赖行为的程序,每个节点上的元数据将包含符号形状(例如s0
、i0
)。此属性将符号形状映射到其下限/上限范围。
Graph¶
Export IR Graph 是以 DAG(有向无环图)形式表示的 PyTorch 程序。此图中的每个节点表示特定的计算或运算,图的边由节点之间的引用组成。
我们可以将 Graph 视为具有以下模式
class Graph:
nodes: List[Node]
实际上,Export IR 的图通过 torch.fx.Graph
Python 类实现。
Export IR 图包含以下节点(节点将在下一节中更详细地描述)
0 个或多个 op 类型为
placeholder
的节点0 个或多个 op 类型为
call_function
的节点恰好 1 个 op 类型为
output
的节点
推论:最小的有效 Graph 将包含一个节点。即,nodes 永远不为空。
定义: Graph 的 placeholder
节点集合表示 GraphModule 的 Graph 的输入。output 节点表示 GraphModule 的 Graph 的输出。
示例
import torch
from torch import nn
class MyModule(nn.Module):
def forward(self, x, y):
return x + y
example_args = (torch.randn(1), torch.randn(1))
mod = torch.export.export(MyModule(), example_args)
print(mod.graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
return (add,)
上图是 Graph 的文本表示,每行是一个节点。
节点¶
Node 表示特定的计算或运算,在 Python 中使用 torch.fx.Node
类表示。节点之间的边通过 Node 类的 args
属性表示为对其他节点的直接引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下运算,例如算子调用、placeholder(即输入)、条件语句和循环。
Node 具有以下模式
class Node:
name: str # name of node
op_name: str # type of operation
# interpretation of the fields below depends on op_name
target: [str|Callable]
args: List[object]
kwargs: Dict[str, object]
meta: Dict[str, object]
FX 文本格式
如上例所示,请注意每行都具有此格式
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})
此格式以紧凑形式捕获 Node 类中的所有内容,但 meta
除外。
具体来说
<name> 是节点的名称,如它在
node.name
中显示的那样。<op_name> 是
node.op
字段,必须是以下之一:<call_function>、<placeholder>、<get_attr> 或 <output>。<target> 是节点的 target,如
node.target
。此字段的含义取决于op_name
。args1, … args 4… 是
node.args
元组中列出的内容。如果列表中的值是torch.fx.Node
,则会特别用前导 %. 指示。
例如,对 add 算子的调用将显示为
%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})
其中 %x
、%y
是另外两个名称为 x 和 y 的 Node。值得注意的是,字符串 torch.op.aten.add.Tensor
表示实际存储在 target 字段中的可调用对象,而不仅仅是其字符串名称。
此文本格式的最后一行是
return [add]
它是一个 op_name = output
的 Node,表示我们正在返回此元素。
call_function¶
一个 call_function
节点表示对算子的调用。
定义
函数式:如果一个可调用对象满足以下所有要求,我们称之为“函数式”
非变异:算子不修改其输入的值(对于张量,这包括元数据和数据)。
无副作用:算子不修改外部可见的状态,例如更改模块参数的值。
算子:是具有预定义模式的函数式可调用对象。此类算子示例包括函数式 ATen 算子。
在 FX 中的表示
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
与普通 FX call_function 的区别
在 FX 图中,一个 call_function 可以引用任何可调用对象,而在 Export IR 中,我们将其限制为仅一部分选定的 ATen 算子、自定义算子和控制流算子。
在 Export IR 中,常量参数将嵌入到图中。
在 FX 图中,get_attr 节点可以表示读取图模块中存储的任何属性。然而,在 Export IR 中,这被限制为仅读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。
元数据¶
Node.meta
是附加到每个 FX 节点的字典。然而,FX 规范并未指定其中可以或将包含哪些元数据。Export IR 提供了更强的契约,特别是所有 call_function
节点都保证具有且仅具有以下元数据字段
node.meta["stack_trace"]
是包含引用原始 Python 源代码的 Python 堆栈跟踪的字符串。示例堆栈跟踪如下所示File "my_module.py", line 19, in forward return x + dummy_helper(y) File "helper_utility.py", line 89, in dummy_helper return y + 1
node.meta["val"]
描述了运行运算的输出。它可以是 <symint>、<FakeTensor>、List[Union[FakeTensor, SymInt]]
或None
类型。node.meta["nn_module_stack"]
描述了节点来源的torch.nn.Module
的“堆栈跟踪”,如果它来自torch.nn.Module
调用的话。例如,如果包含来自torch.nn.Sequential
模块内部的torch.nn.Linear
模块调用的addmm
算子的节点,则nn_module_stack
将类似于{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
node.meta["source_fn_stack"]
包含此节点在分解之前从哪个 torch 函数或叶子torch.nn.Module
类调用。例如,包含来自torch.nn.Linear
模块调用的addmm
算子的节点将在其source_fn
中包含torch.nn.Linear
,而包含来自torch.nn.functional.Linear
模块调用的addmm
算子的节点将在其source_fn
中包含torch.nn.functional.Linear
。
placeholder¶
Placeholder 表示图的输入。其语义与 FX 中完全相同。Placeholder 节点必须是图节点列表中的前 N 个节点。N 可以为零。
在 FX 中的表示
%name = placeholder[target = name](args = ())
target 字段是一个字符串,它是输入的名称。
args
,如果非空,其大小应为 1,表示此输入的默认值。
元数据
Placeholder 节点也具有 meta[‘val’]
,就像 call_function
节点一样。在这种情况下,val
字段表示图期望为此输入参数接收的输入形状/dtype。
output¶
一个 output 调用表示函数中的返回语句;因此它终止当前图。只有一个 output 节点,并且它始终是图的最后一个节点。
在 FX 中的表示
output[](args = (%something, …))
这与 torch.fx
中的语义完全相同。args
表示要返回的节点。
元数据
Output 节点具有与 call_function
节点相同的元数据。
get_attr¶
get_attr
节点表示从封装的 torch.fx.GraphModule
读取子模块。与通过 torch.fx.symbolic_trace()
生成的普通 FX 图不同,在普通 FX 图中,get_attr
节点用于从顶层 torch.fx.GraphModule
读取参数和缓冲区等属性,而在 Export IR 中,参数和缓冲区作为输入传递给图模块,并存储在顶层 torch.export.ExportedProgram
中。
在 FX 中的表示
%name = get_attr[target = name](args = ())
示例
考虑以下模型
from functorch.experimental.control_flow import cond
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
Graph
graph():
%x_1 : [num_users=1] = placeholder[target=x_1]
%y_1 : [num_users=1] = placeholder[target=y_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
return conditional
行 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
读取子模块 true_graph_0
,其中包含 sin
算子。
参考¶
SymInt¶
SymInt 是一个可以表示字面整数或代表整数(在 Python 中由 sympy.Symbol
类表示)的符号的对象。当 SymInt 是一个符号时,它描述了一个在编译时未知、其值仅在运行时已知的整数类型变量。
FakeTensor¶
FakeTensor 是包含张量元数据的对象。可以将其视为具有以下元数据。
class FakeTensor:
size: List[SymInt]
dtype: torch.dtype
device: torch.device
dim_order: List[int] # This doesn't exist yet
FakeTensor 的 size 字段是整数或 SymInt 的列表。如果存在 SymInt,则表示此张量具有动态形状。如果存在整数,则假定该张量将具有该精确的静态形状。TensorMeta 的秩永远不是动态的。dtype 字段表示该节点输出的 dtype。Edge IR 中没有隐式类型提升。FakeTensor 中没有 strides。
换句话说
如果 node.target 中的算子返回 Tensor,则
node.meta['val']
是描述该张量的 FakeTensor。如果 node.target 中的算子返回一个由 Tensor 组成的 n 元组,则
node.meta['val']
是描述每个张量的 FakeTensor 组成的 n 元组。如果 node.target 中的算子返回在编译时已知的 int/float/标量,则
node.meta['val']
为 None。如果 node.target 中的算子返回在编译时未知的 int/float/标量,则
node.meta['val']
的类型为 SymInt。
例如
aten::add
返回一个 Tensor;因此其 spec 将是 FakeTensor,包含此算子返回的张量的 dtype 和 size。aten::sym_size
返回一个整数;因此其 val 将是 SymInt,因为其值仅在运行时可用。max_pool2d_with_indexes
返回一个 (Tensor, Tensor) 元组;因此 spec 也将是由 FakeTensor 对象组成的 2 元组,第一个 TensorMeta 描述返回值的第一个元素,依此类推。
Python 代码
def add_one(x):
return torch.ops.aten(x, 1)
Graph
graph():
%ph_0 : [#users=1] = placeholder[target=ph_0]
%add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
return [add_tensor]
FakeTensor
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
Pytree-able 类型¶
我们定义一种类型为“Pytree-able”,如果它要么是叶子类型,要么是包含其他 Pytree-able 类型的容器类型。
注意
pytree 的概念与此处为 JAX 记录的概念相同
以下类型定义为叶子类型
类型 |
定义 |
---|---|
Tensor |
|
标量 |
Python 中的任何数值类型,包括整数类型、浮点类型和零维张量。 |
int |
Python int(在 C++ 中绑定为 int64_t) |
float |
Python float(在 C++ 中绑定为 double) |
bool |
Python bool |
str |
Python 字符串 |
ScalarType |
|
Layout |
|
MemoryFormat |
|
设备 |
以下类型定义为容器类型
类型 |
定义 |
---|---|
元组 |
Python 元组 |
列表 |
Python 列表 |
字典 |
键为 Scalar 的 Python 字典 |
命名元组 |
Python namedtuple |
Dataclass |
必须通过 register_dataclass 注册 |
自定义类 |
使用 _register_pytree_node 定义的任何自定义类 |