torch.export IR 规范¶
Export IR 是编译器的中间表示 (IR),与 MLIR 和 TorchScript 类似。它专门设计用于表达 PyTorch 程序的语义。Export IR 主要以简化的操作列表的形式表示计算,对控制流等动态性的支持有限。
要创建 Export IR 图,可以使用前端,该前端通过跟踪专门化机制来可靠地捕获 PyTorch 程序。然后,后端可以优化和执行生成的 Export IR。这可以通过 torch.export.export()
来实现。
本文档将介绍的关键概念包括
ExportedProgram:包含 Export IR 程序的数据结构
Graph:由节点列表组成。
Nodes:表示操作、控制流和存储在此节点上的元数据。
节点生成和使用值。
类型与值和节点关联。
还定义了值的尺寸和内存布局。
什么是 Export IR¶
Export IR 是 PyTorch 程序的基于图的中间表示 IR。Export IR 建立在 torch.fx.Graph
之上。换句话说,所有 Export IR 图也是有效的 FX 图,如果使用标准 FX 语义解释,则可以可靠地解释 Export IR。一个含义是,导出的图可以通过标准 FX 代码生成转换为有效的 Python 程序。
本文档将主要关注 Export IR 与 FX 在严格性方面的差异,同时跳过与 FX 相似的部分。
ExportedProgram¶
Export IR 的顶级构造是一个 torch.export.ExportedProgram
类。它将 PyTorch 模型的计算图(通常是 torch.nn.Module
)与该模型使用的参数或权重捆绑在一起。
torch.export.ExportedProgram
类的某些值得注意的属性是
graph_module
(torch.fx.GraphModule
): 包含 PyTorch 模型的扁平化计算图的数据结构。可以通过 ExportedProgram.graph 直接访问该图。graph_signature
(torch.export.ExportGraphSignature
): 图签名,指定图中使用和修改的参数和缓冲区名称。参数和缓冲区不是作为图的属性存储,而是作为图的输入提升。graph_signature 用于跟踪有关这些参数和缓冲区的附加信息。state_dict
(Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
): 包含参数和缓冲区的数据结构。range_constraints
(Dict[sympy.Symbol, RangeConstraint]
): 对于使用数据相关行为导出的程序,每个节点上的元数据将包含符号形状(看起来像s0
、i0
)。此属性将符号形状映射到其下限/上限范围。
Graph¶
Export IR 图是以 DAG(有向无环图)形式表示的 PyTorch 程序。此图中的每个节点都表示特定计算或操作,此图的边由节点之间的引用组成。
我们可以将图视为具有以下模式
class Graph:
nodes: List[Node]
在实践中,Export IR 的图由 torch.fx.Graph
Python 类实现。
Export IR 图包含以下节点(节点将在下一节中详细介绍)
0 个或多个 op 类型为
placeholder
的节点0 个或多个 op 类型为
call_function
的节点恰好 1 个 op 类型为
output
的节点
推论:最小的有效图将只有一个节点,即节点永远不为空。
定义: 图中的一组 placeholder
节点表示 GraphModule 图的 **输入**。图的 output 节点表示 GraphModule 图的 **输出**。
示例
from torch import nn
class MyModule(nn.Module):
def forward(self, x, y):
return x + y
mod = torch.export.export(MyModule())
print(mod.graph)
以上是图的文本表示,每行都是一个节点。
节点¶
节点表示特定计算或操作,在 Python 中使用 torch.fx.Node
类表示。节点之间的边通过 Node 类的 args
属性表示对其他节点的直接引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,例如运算符调用、占位符(也称为输入)、条件语句和循环。
节点具有以下模式
class Node:
name: str # name of node
op_name: str # type of operation
# interpretation of the fields below depends on op_name
target: [str|Callable]
args: List[object]
kwargs: Dict[str, object]
meta: Dict[str, object]
FX 文本格式
如上例所示,请注意每行都具有以下格式
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})
此格式以紧凑的格式捕获 Node 类中存在的所有内容,除了 meta
之外。
具体来说
<name> 是节点的名称,如在
node.name
中所示。<op_name> 是
node.op
字段,它必须是以下之一:<call_function>、<placeholder>、<get_attr> 或 <output>。<target> 是节点的目标,作为
node.target
。此字段的含义取决于op_name
。args1, … args 4… 是
node.args
元组中列出的内容。如果列表中的值是torch.fx.Node
,那么它将用前导的 **%.** 特别地表示。
例如,对 add 运算符的调用将显示为
%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})
其中 %x
、%y
是两个其他具有名称 x 和 y 的节点。值得注意的是,字符串 torch.op.aten.add.Tensor
代表存储在 target 字段中的可调用对象,而不仅仅是它的字符串名称。
此文本格式的最后一行是
return [add]
它是一个具有 op_name = output
的节点,表示我们正在返回此元素。
call_function¶
一个 call_function
节点表示对运算符的调用。
定义
函数式: 我们说一个可调用对象是“函数式”的,如果它满足以下所有要求
非变异:运算符不会改变其输入的值(对于张量,这包括元数据和数据)。
无副作用:运算符不会改变从外部可见的状态,例如更改模块参数的值。
运算符: 是一个具有预定义模式的函数式可调用对象。此类运算符的示例包括函数式 ATen 运算符。
FX 中的表示
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
与普通 FX call_function 的区别
在 FX 图中,call_function 可以引用任何可调用对象,在导出 IR 中,我们将其限制为 ATen 运算符、自定义运算符和控制流运算符的选定子集。
在导出 IR 中,常量参数将嵌入在图中。
在 FX 图中,get_attr 节点可以表示读取存储在图模块中的任何属性。但是,在导出 IR 中,这仅限于读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。
元数据¶
Node.meta
是附加到每个 FX 节点的字典。但是,FX 规范没有指定可以或将存在哪些元数据。导出 IR 提供了一个更强的契约,具体来说,所有 call_function
节点将保证拥有并且仅拥有以下元数据字段
node.meta["stack_trace"]
是一个包含引用原始 Python 源代码的 Python 堆栈跟踪的字符串。一个示例堆栈跟踪如下所示File "my_module.py", line 19, in forward return x + dummy_helper(y) File "helper_utility.py", line 89, in dummy_helper return y + 1
node.meta["val"]
描述运行操作的输出。它可以是类型 <symint>、<FakeTensor>、List[Union[FakeTensor, SymInt]]
或None
。node.meta["nn_module_stack"]
描述节点来自的torch.nn.Module
的“堆栈跟踪”,如果它来自torch.nn.Module
调用。例如,如果包含addmm
操作的节点来自torch.nn.Linear
模块内部的torch.nn.Sequential
模块,则nn_module_stack
将类似于{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
node.meta["source_fn_stack"]
包含在分解之前调用此节点的 torch 函数或叶子torch.nn.Module
类。例如,包含来自torch.nn.Linear
模块调用的addmm
操作的节点将在其source_fn
中包含torch.nn.Linear
,而包含来自torch.nn.functional.Linear
模块调用的addmm
操作的节点将在其source_fn
中包含torch.nn.functional.Linear
。
placeholder¶
Placeholder 表示图的输入。它的语义与 FX 中的语义完全相同。Placeholder 节点必须是图的节点列表中的前 N 个节点。N 可以为零。
FX 中的表示
%name = placeholder[target = name](args = ())
target 字段是一个字符串,它是输入的名称。
args
,如果不为空,应大小为 1,表示此输入的默认值。
元数据
Placeholder 节点也具有 meta[‘val’]
,就像 call_function
节点一样。在这种情况下,val
字段表示图预计为该输入参数接收的输入形状/数据类型。
output¶
一个 output 调用表示函数中的 return 语句;因此它终止当前图。只有一个 output 节点,它始终是图的最后一个节点。
FX 中的表示
output[](args = (%something, …))
这与 torch.fx
中的语义完全相同。 args
表示要返回的节点。
元数据
Output 节点与 call_function
节点具有相同的元数据。
get_attr¶
get_attr
节点表示从封装的 torch.fx.GraphModule
中读取子模块。与 torch.fx.symbolic_trace()
中的普通 FX 图不同,在普通 FX 图中,get_attr
节点用于读取顶级 torch.fx.GraphModule
的属性(例如参数和缓冲区),参数和缓冲区将作为输入传递给图模块,并存储在顶级 torch.export.ExportedProgram
中。
FX 中的表示
%name = get_attr[target = name](args = ())
示例
考虑以下模型
from functorch.experimental.control_flow import cond
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
图
graph():
%x_1 : [num_users=1] = placeholder[target=x_1]
%y_1 : [num_users=1] = placeholder[target=y_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
return conditional
该行,%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
,读取子模块 true_graph_0
,其中包含 sin
运算符。
引用¶
SymInt¶
SymInt 是一个对象,它可以是字面量整数,也可以是表示整数的符号(在 Python 中由 sympy.Symbol
类表示)。当 SymInt 是一个符号时,它描述了在编译时对图未知的类型为整数的变量,也就是说,它的值仅在运行时已知。
FakeTensor¶
FakeTensor 是一个包含张量元数据的对象。它可以被视为具有以下元数据。
class FakeTensor:
size: List[SymInt]
dtype: torch.dtype
device: torch.device
dim_order: List[int] # This doesn't exist yet
FakeTensor 的 size 字段是一个整数或 SymInt 列表。如果存在 SymInt,则意味着此张量具有动态形状。如果存在整数,则假设该张量将具有该确切的静态形状。TensorMeta 的秩永远不会是动态的。dtype 字段表示该节点输出的数据类型。Edge IR 中没有隐式类型提升。FakeTensor 中没有步幅。
换句话说
如果 node.target 中的运算符返回一个张量,那么
node.meta['val']
是一个 FakeTensor,描述该张量。如果 node.target 中的运算符返回一个 n 元组的张量,那么
node.meta['val']
是一个 n 元组的 FakeTensor,描述每个张量。如果 node.target 中的运算符返回一个在编译时已知的 int/float/标量,那么
node.meta['val']
为 None。如果节点 target 中的操作符返回一个在编译时未知的 int/float/标量,那么
node.meta['val']
的类型为 SymInt。
例如
aten::add
返回一个张量;因此,它的规格将是一个 FakeTensor,其数据类型和大小为该操作符返回的张量。aten::sym_size
返回一个整数;因此,它的 val 将是一个 SymInt,因为它的值只有在运行时才能获得。max_pool2d_with_indexes
返回一个 (张量, 张量) 元组;因此,规格也将是一个包含两个 FakeTensor 对象的元组,第一个 TensorMeta 描述返回值的第一个元素等。
Python 代码
def add_one(x):
return torch.ops.aten(x, 1)
图
graph():
%ph_0 : [#users=1] = placeholder[target=ph_0]
%add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
return [add_tensor]
FakeTensor
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
可 Pytree 化类型¶
我们定义了一个类型“可 Pytree 化”,如果它是一个叶子类型或一个包含其他可 Pytree 化类型的容器类型。
注意
Pytree 的概念与 JAX 文档中描述的相同 这里
以下类型被定义为 **叶子类型**
类型 |
定义 |
---|---|
张量 |
|
标量 |
来自 Python 的任何数值类型,包括整型、浮点型和零维张量。 |
int |
Python int(在 C++ 中绑定为 int64_t) |
float |
Python float(在 C++ 中绑定为 double) |
bool |
Python bool |
str |
Python 字符串 |
ScalarType |
|
布局 |
|
内存格式 |
|
设备 |
以下类型被定义为 **容器类型**
类型 |
定义 |
---|---|
元组 |
Python 元组 |
列表 |
Python 列表 |
字典 |
带有标量键的 Python 字典 |
命名元组 |
Python 命名元组 |
数据类 |
必须通过 register_dataclass 注册 |
自定义类 |
使用 _register_pytree_node 定义的任何自定义类 |