torch.export IR 规范¶
Export IR 是一个用于编译器的中间表示 (IR),它与 MLIR 和 TorchScript 有相似之处。它专门设计用于表达 PyTorch 程序的语义。Export IR 主要以简化的操作列表形式表示计算,对动态机制(如控制流)的支持有限。
要创建 Export IR 图,可以使用前端,该前端通过跟踪专业化机制来可靠地捕获 PyTorch 程序。然后,后端可以优化和执行生成的 Export IR。这可以通过今天的 torch.export.export()
来完成。
本文档将涵盖的关键概念包括
ExportedProgram:包含 Export IR 程序的数据结构
图:由节点列表组成。
节点:代表操作、控制流和存储在此节点上的元数据。
值由节点产生和使用。
类型与值和节点关联。
还定义了值的尺寸和内存布局。
什么是 Export IR¶
Export IR 是 PyTorch 程序的基于图的中间表示 IR。Export IR 是在 torch.fx.Graph
之上实现的。换句话说,所有 Export IR 图都是有效的 FX 图,如果使用标准 FX 语义解释,Export IR 可以被可靠地解释。一个含义是,可以通过标准 FX 代码生成将导出的图转换为有效的 Python 程序。
本文档将主要侧重于突出显示 Export IR 在严格性方面与 FX 的不同之处,同时跳过与 FX 共有相似之处的部分。
ExportedProgram¶
顶级 Export IR 构造是一个 torch.export.ExportedProgram
类。它将 PyTorch 模型(通常是 torch.nn.Module
)的计算图捆绑在一起,以及该模型使用的参数或权重。
torch.export.ExportedProgram
类的一些值得注意的属性是
graph_module
(torch.fx.GraphModule
): 包含 PyTorch 模型的扁平化计算图的数据结构。可以通过 ExportedProgram.graph 直接访问图。graph_signature
(torch.export.ExportGraphSignature
): 图签名,它指定了图内使用和修改的参数和缓冲区名称。参数和缓冲区不是存储为图的属性,而是作为图的输入提升的。graph_signature 用于跟踪这些参数和缓冲区的附加信息。state_dict
(Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
): 包含参数和缓冲区的数据结构。range_constraints
(Dict[sympy.Symbol, RangeConstraint]
): 对于使用数据相关行为导出的程序,每个节点上的元数据将包含符号形状(看起来像s0
、i0
)。此属性将符号形状映射到它们的下限/上限范围。
图¶
导出 IR 图是使用 DAG(有向无环图)形式表示的 PyTorch 程序。图中的每个节点表示一个特定的计算或操作,图的边表示节点之间的引用。
我们可以将图视为具有以下结构的图
class Graph:
nodes: List[Node]
在实践中,导出 IR 图是通过 torch.fx.Graph
Python 类实现的。
导出 IR 图包含以下节点(节点将在下一节中详细介绍)
0 个或多个操作类型为
placeholder
的节点0 个或多个操作类型为
call_function
的节点正好 1 个操作类型为
output
的节点
推论:最小的有效图将只有一个节点。即节点列表永远不为空。
定义:图中 placeholder
节点集表示 GraphModule 图的输入。图中的 output 节点表示 GraphModule 图的输出。
示例
from torch import nn
class MyModule(nn.Module):
def forward(self, x, y):
return x + y
mod = torch.export.export(MyModule())
print(mod.graph)
以上是图的文本表示,每行都是一个节点。
节点¶
节点表示一个特定的计算或操作,使用 Python 中的 torch.fx.Node
类表示。节点之间的边表示为通过 Node 类的 args
属性对其他节点的直接引用。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,例如操作符调用、占位符(又名输入)、条件语句和循环。
节点具有以下结构
class Node:
name: str # name of node
op_name: str # type of operation
# interpretation of the fields below depends on op_name
target: [str|Callable]
args: List[object]
kwargs: Dict[str, object]
meta: Dict[str, object]
FX 文本格式
如上面的示例所示,请注意每行都具有以下格式
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})
此格式以紧凑的格式捕获 Node 类中存在的所有内容,除了 meta
。
具体来说
<name> 是节点的名称,它将出现在
node.name
中。<op_name> 是
node.op
字段,它必须是以下之一:<call_function>、<placeholder>、<get_attr> 或 <output>。<target> 是节点的目标,作为
node.target
。此字段的含义取决于op_name
。args1, … args 4… 是
node.args
元组中列出的内容。如果列表中的值是torch.fx.Node
,则它将以一个前导的 %. 特别表示。
例如,对 add 操作符的调用将显示为
%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})
其中 %x
、%y
是两个其他节点,它们的名称分别为 x 和 y。值得注意的是,字符串 torch.op.aten.add.Tensor
表示实际上存储在 target 字段中的可调用对象,而不仅仅是它的字符串名称。
此文本格式的最后一行是
return [add]
这是一个具有 op_name = output
的节点,表明我们正在返回此元素。
call_function¶
一个 call_function
节点表示对一个操作符的调用。
定义
函数式:我们说一个可调用对象是“函数式”的,如果它满足以下所有要求
非变异:操作符不改变其输入的值(对于张量,这包括元数据和数据)。
没有副作用:操作符不会改变从外部可见的状态,例如改变模块参数的值。
操作符:是一个具有预定义结构的函数式可调用对象。此类操作符的示例包括函数式 ATen 操作符。
在 FX 中的表示
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
与普通 FX call_function 的区别
在 FX 图中,一个 call_function 可以引用任何可调用对象,而在导出 IR 中,我们将其限制为 ATen 操作符、自定义操作符和控制流操作符的特定子集。
在导出 IR 中,常量参数将嵌入到图中。
在 FX 图中,一个 get_attr 节点可以表示读取存储在图模块中的任何属性。但是,在导出 IR 中,这仅限于读取子模块,因为所有参数/缓冲区都将作为输入传递给图模块。
元数据¶
Node.meta
是附加到每个 FX 节点的字典。但是,FX 规范没有指定元数据可以或将存在的内容。导出 IR 提供了一个更强的约定,具体来说,所有 call_function
节点将保证具有并且仅具有以下元数据字段
node.meta["stack_trace"]
是一个字符串,包含引用原始 Python 源代码的 Python 堆栈跟踪。一个示例堆栈跟踪看起来像File "my_module.py", line 19, in forward return x + dummy_helper(y) File "helper_utility.py", line 89, in dummy_helper return y + 1
node.meta["val"]
描述运行操作的输出。它可以是 <symint>、<FakeTensor>、List[Union[FakeTensor, SymInt]]
或None
类型的。node.meta["nn_module_stack"]
描述来自节点的torch.nn.Module
的“堆栈跟踪”,如果它是来自torch.nn.Module
调用的。例如,如果包含addmm
操作的节点来自torch.nn.Linear
模块内部的一个torch.nn.Sequential
模块,则nn_module_stack
将看起来像{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
node.meta["source_fn_stack"]
包含在分解之前调用此节点的 torch 函数或叶torch.nn.Module
类。例如,包含来自torch.nn.Linear
模块调用的addmm
操作的节点将在其source_fn
中包含torch.nn.Linear
,而包含来自torch.nn.functional.Linear
模块调用的addmm
操作的节点将在其source_fn
中包含torch.nn.functional.Linear
。
placeholder¶
占位符表示图的输入。它的语义与 FX 中完全相同。占位符节点必须是图的节点列表中的前 N 个节点。N 可以为零。
在 FX 中的表示
%name = placeholder[target = name](args = ())
target 字段是一个字符串,表示输入的名称。
args
(如果非空)应该为大小为 1,表示此输入的默认值。
元数据
占位符节点也有 meta[‘val’]
,就像 call_function
节点一样。在这种情况下,val
字段表示图期望为该输入参数接收的输入形状/数据类型。
output¶
输出调用表示函数中的 return 语句;因此它终止当前图。有一个且只有一个输出节点,它将始终是图的最后一个节点。
在 FX 中的表示
output[](args = (%something, …))
这具有与 torch.fx
中完全相同的语义。 args
表示要返回的节点。
元数据
输出节点具有与 call_function
节点相同的元数据。
get_attr¶
get_attr
节点表示从封装的 torch.fx.GraphModule
中读取子模块。与来自 torch.fx.symbolic_trace()
的普通 FX 图不同,在普通 FX 图中,get_attr
节点用于从顶层 torch.fx.GraphModule
读取参数和缓冲区,参数和缓冲区作为输入传递给图模块,并存储在顶层 torch.export.ExportedProgram
中。
在 FX 中的表示
%name = get_attr[target = name](args = ())
示例
考虑以下模型
from functorch.experimental.control_flow import cond
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
图
graph():
%x_1 : [num_users=1] = placeholder[target=x_1]
%y_1 : [num_users=1] = placeholder[target=y_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
return conditional
这一行,%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
,读取子模块 true_graph_0
,该子模块包含 sin
操作符。
引用¶
SymInt¶
SymInt 是一个对象,它可以是字面量整数,也可以是表示整数的符号(在 Python 中由 sympy.Symbol
类表示)。当 SymInt 是符号时,它描述了编译时对图未知的整数类型变量,也就是说,它的值只在运行时才知道。
FakeTensor¶
FakeTensor 是一个包含张量元数据的对象。可以将其视为具有以下元数据。
class FakeTensor:
size: List[SymInt]
dtype: torch.dtype
device: torch.device
dim_order: List[int] # This doesn't exist yet
FakeTensor 的 size 字段是一个整数或 SymInt 列表。如果存在 SymInt,则表示此张量具有动态形状。如果存在整数,则假定张量将具有该确切的静态形状。TensorMeta 的秩永远不会是动态的。dtype 字段表示该节点输出的 dtype。Edge IR 中没有隐式类型提升。FakeTensor 中没有步长。
换句话说
如果 node.target 中的操作符返回一个张量,则
node.meta['val']
是一个 FakeTensor,描述了该张量。如果 node.target 中的操作符返回一个 n 元组的张量,则
node.meta['val']
是一个 n 元组的 FakeTensor,描述了每个张量。如果 node.target 中的操作符返回一个编译时已知的 int/float/标量,则
node.meta['val']
为 None。如果 node.target 中的操作符返回一个编译时未知的 int/float/标量,则
node.meta['val']
是 SymInt 类型。
例如
aten::add
返回一个张量;因此它的规范将是一个 FakeTensor,具有该操作符返回的张量的 dtype 和大小。aten::sym_size
返回一个整数;因此它的 val 将是一个 SymInt,因为它的值只有在运行时才能获得。max_pool2d_with_indexes
返回一个 (Tensor, Tensor) 元组;因此规范也将是两个 FakeTensor 对象的 2 元组,第一个 TensorMeta 描述了返回值的第一个元素,等等。
Python 代码
def add_one(x):
return torch.ops.aten(x, 1)
图
graph():
%ph_0 : [#users=1] = placeholder[target=ph_0]
%add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
return [add_tensor]
FakeTensor
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
Pytree 可用类型¶
我们定义一个类型“Pytree 可用”,如果它是叶子类型或包含其他 Pytree 可用类型的容器类型。
注意
pytree 的概念与 JAX 中记录的 此处 相同
以下类型被定义为 **叶子类型**
类型 |
定义 |
---|---|
张量 |
|
标量 |
来自 Python 的任何数值类型,包括整数类型、浮点类型和零维张量。 |
int |
Python int(在 C++ 中绑定为 int64_t) |
float |
Python float(在 C++ 中绑定为 double) |
bool |
Python bool |
str |
Python 字符串 |
ScalarType |
|
布局 |
|
内存格式 |
|
设备 |
以下类型被定义为 **容器类型**
类型 |
定义 |
---|---|
元组 |
Python 元组 |
列表 |
Python 列表 |
字典 |
具有标量键的 Python 字典 |
命名元组 |
Python 命名元组 |
数据类 |
必须通过 register_dataclass 注册 |
自定义类 |
使用 _register_pytree_node 定义的任何自定义类 |