torch.export IR 规范¶
Export IR 是编译器的中间表示 (IR),与 MLIR 和 TorchScript 类似。它专门设计用于表示 PyTorch 程序的语义。Export IR 主要以简化的操作列表表示计算,对动态性(如控制流)的支持有限。
要创建 Export IR 图,可以使用前端通过跟踪专门机制可靠地捕获 PyTorch 程序。然后,可以通过后端优化和执行生成的 Export IR。这可以通过 torch.export.export()
来完成。
本文档中将涵盖的关键概念包括
ExportedProgram:包含导出 IR 程序的数据结构
Graph:由节点列表组成。
Nodes:表示存储在此节点上的操作、控制流和元数据。
值由节点产生并消耗。
类型与值和节点相关联。
还定义了值的大小和内存布局。
什么是导出 IR¶
导出 IR 是 PyTorch 程序的基于图的中间表示 IR。导出 IR 在 torch.fx.Graph
的基础上实现。换句话说,所有导出 IR 图也是有效的 FX 图,如果使用标准 FX 语义进行解释,则可以合理地解释导出 IR。一个含义是,可以通过标准 FX 代码生成将导出的图转换为有效的 Python 程序。
本说明文档将主要集中于突出显示导出 IR 在严格性方面与 FX 不同之处,同时跳过它与 FX 相似之处。
ExportedProgram¶
顶级导出 IR 构造是 torch.export.ExportedProgram
类。它将 PyTorch 模型(通常是 torch.nn.Module
)的计算图与该模型消耗的参数或权重捆绑在一起。
torch.export.ExportedProgram
类的某些显着属性是
graph_module
(torch.fx.GraphModule
): 包含 PyTorch 模型的扁平化计算图的数据结构。可以通过 ExportedProgram.graph 直接访问该图。graph_signature
(torch.export.ExportGraphSignature
): 图形签名,指定图形中使用的和突变的参数和缓冲区名称。不是将参数和缓冲区存储为图形的属性,而是将其提升为图形的输入。graph_signature 用于跟踪这些参数和缓冲区的附加信息。state_dict
(Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
): 包含参数和缓冲区的数据结构。range_constraints
(Dict[sympy.Symbol, RangeConstraint]
): 对于导出具有数据相关行为的程序,每个节点的元数据将包含符号形状(看起来像s0
、i0
)。此属性将符号形状映射到其下/上限。
图形¶
导出 IR 图形是以 DAG(有向无环图)形式表示的 PyTorch 程序。此图形中的每个节点表示特定的计算或操作,此图形的边由节点之间的引用组成。
我们可以查看具有此架构的图形
class Graph:
nodes: List[Node]
实际上,导出 IR 的图形实现为 torch.fx.Graph
Python 类。
导出 IR 图形包含以下节点(将在下一部分中更详细地描述节点)
0 个或更多 op 类型为
placeholder
的节点0 个或更多 op 类型为
call_function
的节点恰好 1 个 op 类型为
output
的节点
推论:最小的有效图形将由一个节点组成。即节点永远不会为空。
定义:图形的 placeholder
节点集合表示 GraphModule 的图形的输入。图形的 output 节点表示 GraphModule 的图形的输出。
示例
from torch import nn
class MyModule(nn.Module):
def forward(self, x, y):
return x + y
mod = torch.export.export(MyModule())
print(mod.graph)
上面是图的文本表示,其中每一行都是一个节点。
节点¶
节点表示特定的计算或操作,并使用 torch.fx.Node
类在 Python 中表示。节点之间的边通过 Node 类的 args
属性表示为对其他节点的直接引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,例如运算符调用、占位符(又称输入)、条件和循环。
节点具有以下模式
class Node:
name: str # name of node
op_name: str # type of operation
# interpretation of the fields below depends on op_name
target: [str|Callable]
args: List[object]
kwargs: Dict[str, object]
meta: Dict[str, object]
FX 文本格式
如上例所示,请注意每一行都具有此格式
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})
此格式以紧凑的格式捕获 Node 类中存在的所有内容,但 meta
除外。
具体来说
<name> 是节点的名称,它将显示在
node.name
中。<op_name> 是
node.op
字段,它必须是以下之一:<call_function>、<placeholder>、<get_attr> 或 <output>。<target> 是节点的目标,如
node.target
。此字段的含义取决于op_name
。args1、…args 4… 是
node.args
元组中列出的内容。如果列表中的值是torch.fx.Node
,那么它将特别用前导 %. 指示。
例如,对 add 运算符的调用将显示为
%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})
其中 %x
、%y
是另外两个具有名称 x 和 y 的节点。值得注意的是,字符串 torch.op.aten.add.Tensor
表示实际存储在 target 字段中的可调用对象,而不仅仅是它的字符串名称。
此文本格式的最后一行是
return [add]
这是一个具有 op_name = output
的节点,表示我们正在返回此一个元素。
call_function¶
call_function
节点表示对运算符的调用。
定义
功能:如果可调用项满足以下所有要求,则称其为“功能”
非突变:运算符不会改变其输入的值(对于张量,这包括元数据和数据)。
无副作用:运算符不会改变从外部可见的状态,例如更改模块参数的值。
运算符:是具有预定义架构的功能可调用项。此类运算符的示例包括功能性 ATen 运算符。
在 FX 中表示
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
与香草 FX call_function 的区别
在 FX 图中,call_function 可以引用任何可调用项,在导出 IR 中,我们将其限制为 ATen 运算符、自定义运算符和控制流运算符的选定子集。
在导出 IR 中,常量参数将嵌入在图中。
在 FX 图中,get_attr 节点可以表示读取存储在图模块中的任何属性。但是,在导出 IR 中,这仅限于重新读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。
元数据¶
Node.meta
是附加到每个 FX 节点的字典。但是,FX 规范没有指定可以或将存在哪些元数据。导出 IR 提供了更强的契约,特别是所有 call_function
节点将保证具有且仅具有以下元数据字段
node.meta["stack_trace"]
是一个包含 Python 堆栈跟踪的字符串,引用原始 Python 源代码。示例堆栈跟踪如下所示File "my_module.py", line 19, in forward return x + dummy_helper(y) File "helper_utility.py", line 89, in dummy_helper return y + 1
node.meta["val"]
描述了运行操作的输出。它可以是类型 <symint>、<FakeTensor>、List[Union[FakeTensor, SymInt]]
或None
。node.meta["nn_module_stack"]
描述了torch.nn.Module
的“堆栈跟踪”,如果该节点来自torch.nn.Module
调用,则来自该torch.nn.Module
。例如,如果一个包含addmm
操作的节点从torch.nn.Linear
模块中torch.nn.Sequential
模块内部调用,则nn_module_stack
看起来像{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
node.meta["source_fn_stack"]
包含分解之前的 torch 函数或叶torch.nn.Module
类,该节点从中调用。例如,一个包含来自torch.nn.Linear
模块调用的addmm
操作的节点将包含torch.nn.Linear
在其source_fn
中,而一个包含来自torch.nn.functional.Linear
模块调用的addmm
操作的节点将包含torch.nn.functional.Linear
在其source_fn
中。
占位符¶
占位符表示图的输入。其语义与 FX 中完全相同。占位符节点必须是图的节点列表中的前 N 个节点。N 可以为零。
在 FX 中表示
%name = placeholder[target = name](args = ())
target 字段是一个字符串,表示输入的名称。
args
(如果非空)的大小应为 1,表示此输入的默认值。
元数据
占位符节点还具有 meta[‘val’]
,如 call_function
节点。在这种情况下,val
字段表示图预期为此输入参数接收的输入形状/数据类型。
output¶
输出调用表示函数中的 return 语句;因此它会终止当前图。只有一个输出节点,它始终是图的最后一个节点。
在 FX 中表示
output[](args = (%something, …))
这与 torch.fx
中的语义完全相同。 args
表示要返回的节点。
元数据
输出节点具有与 call_function
节点相同的元数据。
get_attr¶
get_attr
节点表示从封装的 torch.fx.GraphModule
读取子模块。与 torch.fx.symbolic_trace()
中的香草 FX 图不同,在香草 FX 图中,get_attr
节点用于从顶级 torch.fx.GraphModule
读取诸如参数和缓冲区之类的属性,而参数和缓冲区被作为输入传递到图模块中,并存储在顶级 torch.export.ExportedProgram
中。
在 FX 中表示
%name = get_attr[target = name](args = ())
示例
考虑以下模型
from functorch.experimental.control_flow import cond
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
图
graph():
%x_1 : [num_users=1] = placeholder[target=x_1]
%y_1 : [num_users=1] = placeholder[target=y_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
return conditional
行 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
读取包含 sin
运算符的子模块 true_graph_0
。
引用¶
SymInt¶
SymInt 是一个对象,它可以是文字整数,也可以是表示整数的符号(在 Python 中由 sympy.Symbol
类表示)。当 SymInt 是符号时,它描述了一个在编译时对图未知的整数类型变量,即,其值仅在运行时已知。
FakeTensor¶
FakeTensor 是一个包含张量元数据的对象。可以将它视为具有以下元数据。
class FakeTensor:
size: List[SymInt]
dtype: torch.dtype
device: torch.device
dim_order: List[int] # This doesn't exist yet
FakeTensor 的大小字段是整数或 SymInt 的列表。如果存在 SymInt,则表示此张量具有动态形状。如果存在整数,则假定张量将具有该确切的静态形状。TensorMeta 的秩永远不是动态的。dtype 字段表示该节点输出的数据类型。Edge IR 中没有隐式类型提升。FakeTensor 中没有步幅。
换句话说
如果 node.target 中的运算符返回张量,则
node.meta['val']
是描述该张量的 FakeTensor。如果 node.target 中的运算符返回张量的 n 元组,则
node.meta['val']
是描述每个张量的 FakeTensor 的 n 元组。如果 node.target 中的运算符返回在编译时已知的 int/float/标量,则
node.meta['val']
为 None。如果 node.target 中的运算符返回在编译时未知的 int/float/标量,则
node.meta['val']
的类型为 SymInt。
例如
aten::add
返回张量;因此,其规格将是 FakeTensor,其中 dtype 和 size 为此运算符返回的张量的大小。aten::sym_size
返回整数;因此,其 val 将是 SymInt,因为其值仅在运行时可用。max_pool2d_with_indexes
返回 (Tensor, Tensor) 元组;因此,规格也将是 FakeTensor 对象的 2 元组,第一个 TensorMeta 描述返回值的第一个元素等。
Python 代码
def add_one(x):
return torch.ops.aten(x, 1)
图
graph():
%ph_0 : [#users=1] = placeholder[target=ph_0]
%add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
return [add_tensor]
FakeTensor
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
可 Pytree 的类型¶
如果它是一个叶类型或包含其他可 Pytree 类型的容器类型,则我们定义一个“可 Pytree”类型。
注意
pytree 的概念与 此处 为 JAX 记录的概念相同
以下类型被定义为叶类型
类型 |
定义 |
---|---|
Tensor |
|
标量 |
Python 中的任何数值类型,包括整数类型、浮点类型和零维张量。 |
int |
Python int(在 C++ 中绑定为 int64_t) |
float |
Python float(在 C++ 中绑定为 double) |
bool |
Python bool |
str |
Python 字符串 |
ScalarType |
|
Layout |
|
MemoryFormat |
|
Device |
以下类型被定义为容器类型
类型 |
定义 |
---|---|
Tuple |
Python 元组 |
List |
Python 列表 |
Dict |
具有标量键的 Python 字典 |
NamedTuple |
Python 命名元组 |
Dataclass |
必须通过 register_dataclass 注册 |
自定义类 |
使用 _register_pytree_node 定义的任何自定义类 |