自定义编译器 Pass 和分区器¶
Pass¶
Pass 可以大致分为几个轴
轴 A
创建一对多映射(例如,分解)
创建多对一映射(例如,融合)
轴 B
执行前向迭代(例如,形状传播)
执行后向迭代(例如,死代码消除)
轴 C
依赖于本地节点信息(例如,输出变量转换)
依赖于全局图信息(例如,内存规划)
我们对这些用例频率的预测是
A.1、B.1、C.1
A.2
B.2、C.2
级别 1¶
对于级别 1 用例(创建一对多映射、执行前向迭代和查看本地节点信息),我们可以利用一个名为 ExportPass
的辅助类。这是一种 基于解释器 的方法,我们可以在其中执行每个节点并重新创建图,但会进行指定的转换。这使我们能够通过确保在 Pass 中创建的所有节点都满足 IR 规范(包括确保保留堆栈跟踪、FakeTensor 值和 torch.nn.Module 层次结构,并根据所做的转换进行更新)来保留 IR 规范。
要实现此 Pass,我们可以创建 ExportPass
的子类并实现公开的函数。当使用图模块调用时,它将运行图模块并创建一个包含 Pass 指定更改的新图。这意味着传入的图模块必须可在 CPU 上运行,并且此不变量将在 Pass 运行后保持不变。
一对一 Pass¶
对于一对一映射的示例,如果我们想用另一个运算符 B 替换运算符 A,我们可以运行给定的 fx.GraphModule
,并且每次我们看到运算符 A 时,都返回运算符 B。
考虑以下示例
class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
"""
relu_ is the in-place version. Replace it with relu, which is the
out-of-place version
"""
def call_operator(self, op, args, kwargs, meta):
if op != torch.ops.aten.relu_.default:
return super().call_operator(op, args, kwargs, meta)
return super().call_operator(Op(torch.ops.aten.relu.default), args, kwargs, meta)
# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module
super().call_operator(op, args, kwargs, meta)
调用创建一个 call_function
FX 节点,并返回使用给定参数运行运算符的结果。
一对多 Pass¶
如果我们想进行一对多映射,例如用 2 个其他运算符 B 和 C 替换运算符 A,那么我们将进行 2 次 super().call_operator
调用来创建 2 个 FX 节点,一个使用运算符 B,另一个使用运算符 C,并返回运行运算符 C 的结果。
例如
class ReplaceAddWithMulSub(ExportPass):
"""
Original:
def f(x, y):
return x + y
After pass:
def f(x, y):
z = x * y
return z - y
"""
def call_operator(self, op, args, kwargs, meta):
if op != torch.ops.aten.add.default:
return super().call_operator(op, args, kwargs, meta)
x, y = args
mul_res = super().call_operator(
torch.ops.aten.mul.default,
args,
{},
meta
)
return super().call_operator(
torch.ops.aten.sub.default,
(mul_res, y),
{},
meta
)
一对零 Pass¶
如果我们想删除一个运算符,我们可以只返回传递给函数的值
class RemoveDetachPass(ExportPass):
def call_operator(self, op, args, kwargs, meta):
if op not in (
torch.ops.aten.detach.default,
torch.ops.aten.detach_copy.default,
):
return super().call_operator(op, args, kwargs, meta)
assert len(args) == 1
return args[0]
利用本地信息¶
利用本地节点信息的示例是,如果我们想将图中的所有标量转换为张量,我们可以运行给定的 fx.GraphModule
,并且对于包含标量的每个参数,我们将其转换为张量。它可能看起来像这样
def args_map(op, fn, args, kwargs):
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
args = list(args)
kwargs = kwargs.copy()
# Update the argument based on the function passed
def update(key, args, schema):
args[key] = fn(args[key], schema)
# Update each argument in the schema
for i, schema in enumerate(self.op._schema.arguments):
if schema.name in kwargs:
update(schema.name, kwargs, schema)
elif not schema.kwarg_only and i < len(args):
update(i, args, schema)
class ScalarToTensorPass(ExportPass):
def call_operator(self, op, args, kwargs):
def try_coerce(value, arg):
return (
torch.tensor(value)
if isinstance(value, (float, int, bool))
and type(arg.type) == torch.TensorType
else value
)
args, kwargs = args_map(op, try_coerce, args, kwargs)
return super().call_operator(op, args, kwargs)
级别 2¶
为了创建多对一映射,我们可以利用 FX 的 子图重写器。给定一个 pattern
,它会创建一个与模式匹配的运算符子图,然后用 replacement
替换每个匹配的子图。
注意
This is an inplace operation.
pattern
和 replacement
输入必须是使用与您要匹配的 EXIR 图中使用的相同运算符(ATen 运算符)编写的可调用函数,以便子图重写器可以在图中找到正确的模式。模式/替换可调用对象的输入将被视为通配符。
考虑以下示例
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x
def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
traced_module, pattern, replacement
)
子图重写器返回一个 ReplacedPatterns
列表
@dataclass
class ReplacedPatterns:
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
# List of nodes that were added into the graph
replacements: List[Node]
注意
The nodes created by the subgraph rewriter will not have the metadata that
is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).
级别 3¶
对于创建 Pass 的第三种方法,我们可以利用最基本的 PassBase
。要创建 Pass,我们可以继承此类并使用 Pass 内容实现函数 call
。此外,我们可以实现函数 requires
和 ensures
,它们将在函数 call
之前和之后调用。请注意,这些函数也可以在 ExportPass
中重写。要在图模块上运行 Pass,我们可以将图模块直接传递给类的实例。
考虑以下示例
class ReplaceAddPass(PassBase):
def __init__(self, replace_op):
self.replace_op = replace_op
def call(self, graph_module):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.add:
node.target = self.replace_op
# Optional to implement, will be called before call()
def requires(self, graph_module) -> None:
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == torch.add:
return
raise ValueError("No torch.add ops!")
# Optional to implement, will be called after call()
def ensures(self, graph_module: torch.fx.GraphModule) -> None:
pass
# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)
Pass 管理器¶
PassManager
是一个用于在给定图模块上运行多个 Pass 的类。初始化 PassManager
实例时,我们传入要运行的 Pass 列表并设置几个标志。要在图模块上运行 Pass 集合,我们可以将图模块直接传递给 PassManager
实例。
示例
from executorch.exir.pass_manager import PassManager
pm = PassManager(
passes=[replace_add_with_div, replace_div_with_mul],
run_checks_after_each_pass=True,
suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
要添加在每个 Pass 之后运行的一组通用检查,我们可以调用函数 set_checks(check: Callable)
,该函数将可调用函数作为输入。如果设置了 run_checks_after_each_pass
标志,则在图模块上运行每个 Pass 后,将调用 check
。
示例
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target != torch.div:
raise ValueError("Target should be div!")
pm.add_checks(check_div_target)
pm(graph_module) # raises ValueError after replace_div_with_mul pass
分区器¶
我们可以使用几个常见的基于 FX 图的分区器来对图进行分区。但是,这些分区器不一定会生成符合 IR 规范的图,因此在使用时请小心。
子图匹配器¶
为了在图中查找与特定模式匹配的子图,我们可以利用 FX 的 SubgraphMatcher
。
类属性
pattern (Graph)
:目标匹配模式。图中的占位符节点在匹配时将被视为通配符。match_output (bool)
:如果为 True,则模式图中的输出节点将被视为目标模式的一部分。如果为 False,则在匹配期间忽略输出节点。match_placeholder (bool)
:如果为 True,则模式图中的占位符节点将被视为目标模式的一部分。如果为 False,则占位符节点将用作通配符。remove_overlapping_matches (bool)
:如果为 True,则在重叠匹配的情况下,仅返回第一个匹配项。ignore_literals (bool)
:如果为 True,则不会检查字面量是否相等,而是将其视为通配符。
考虑以下示例
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
class LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight = torch.nn.Parameter(torch.ones(3, 3))
self._bias = torch.nn.Parameter(torch.ones(3, 3))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph
class PatternModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
match
函数返回一个 InternalMatch
列表
@dataclass
class InternalMatch():
# Nodes from which the match was found
anchors: List[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict)
# Nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list)
# Nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list)
基于能力的分区器¶
为了找到支持特定不变量的最大节点子图,我们可以利用 FX 的 CapabilityBasedPartitioner
。
类属性
graph_module (torch.fx.GraphModule)
:我们要在其上进行分区的图模块。operator_support (OperatorSupportBase)
:用于确定图中节点是否在分区中受支持的对象。allows_single_node_partition (bool)
:如果为 True,则允许形成单节点分区。non_compute_ops (Optional[Sequence[str]])
:一组被认为是“非计算”的运算符(例如torch.ops.aten.view
和_operator.getitem
),以便分区器不会创建仅包含这些非计算运算符的图allowed_single_node_partition_ops (Optional[Sequence[str]])
:一组允许在单节点分区中的运算符。
OperatorSupportBase
类由分区器用于确定图中特定节点是否属于分区。这是通过重写 is_node_supported
函数来完成的。您可以使用 chain
(如果任何 OperatorSupportBase 返回 False,则返回 False)和 any_chain
(如果任何 OperatorSupportBase 返回 True,则返回 True)链接多个 OperatorSuppportBase
。
考虑以下示例
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
]
capability_partitioner = CapabilityBasedPartitioner(
graph_module,
op_support,
)
# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
如果您查看基于能力的分区器,您可能还会找到一个 fuse_partition
函数,该函数将返回一个修改后的图,其中分区作为子模块,并通过 call_module
节点在顶层图中调用这些子模块。但是,这不符合 IR 规范,因为我们不允许 call_module
节点。
组合¶
我们还提供了一个组合辅助函数:generate_pattern_op_partitions
Args
graph_module (fx.GraphModule)
:我们要分区的模块patterns (List[torch.fx.Graph])
:torch.fx.Graph 格式的模式列表。这些图可以通过从 exir.capture(推荐)或符号跟踪(可能不会产生准确的边缘方言图)获得的 GraphModule 获取graph
字段来获得,或者通过手动制作图模块来获得。op_support (OperatorSupportBase)
:可以按以下方式创建的 OperatorSupportBase直接继承它并实现
is_node_supported()
获取
create_op_support()
的结果获取
create_pattern_support()
的结果使用
chain()
或any_chain()
链接在一起的多个 OperatorSupportBase 类
返回
包含给定 OperatorSupportBase 对象和给定模式图的并集支持的节点的分区列表(最大可能的子图)。
源分区器¶
对于更复杂的用例,用户希望基于更高级别的模块(torch.nn.Linear
或 torch.nn.functional.Linear
)进行分区,这些模块现在被分解为其运算符(aten.permute
、aten.addmm
),我们有以下 辅助函数
get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]
Args
graph
:我们要分区的图wanted_sources
:从该源分解的节点源列表。这可以是函数(例如torch.nn.functional.linear
)或叶模块类型(例如torch.nn.Linear
)
返回
将源(例如
torch.nn.modules.linear.Linear
)映射到SourcePartitions
列表的字典,这些列表对应于从该类型的模块展平的节点列表。
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
# Module type
module_type: Type
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
示例
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(3, 3)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(3, 5)
def forward(self, x):
x = self.linear1(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
%arg0 : [#users=1] = placeholder[target=arg0]
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
%_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
%permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
%_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
%relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
return [addmm_default_2]
"""
module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],
<class 'torch.nn.modules.activation.ReLU'>: [
ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""