• 文档 >
  • 自定义编译器传递和分区器
快捷方式

自定义编译器传递和分区器

传递

传递可以大致分为几个方面

方面 A

  1. 创建一对多映射(例如,分解)

  2. 创建多对一映射(例如,融合)

方面 B

  1. 执行正向迭代(例如,形状传播)

  2. 执行反向迭代(例如,死代码消除)

方面 C

  1. 依赖于局部节点信息(例如,输出变体转换)

  2. 依赖于全局图信息(例如,内存规划)

我们对这些用例频率的预测是

  1. A.1, B.1, C.1

  2. A.2

  3. B.2, C.2

级别 1

对于级别 1 用例(创建一对多映射,执行正向迭代以及查看局部节点信息),我们可以使用一个名为 ExportPass 的辅助类。这是一种 基于解释器 的方法,我们执行每个节点并重新创建图形,但包含指定的转换。这使我们能够通过确保所有在传递过程中创建的节点都满足 IR 规范(包括确保元数据(如堆栈跟踪、FakeTensor 值和 torch.nn.Module 层次结构)得到保留和更新(取决于所做的转换)来保留 IR 规范。

要实现此传递,我们可以创建 ExportPass 的子类并实现公开的函数。当用图形模块调用时,它将运行图形模块并创建一个包含指定传递更改的新图形。这意味着传递的图形模块必须在 CPU 上可运行,并且在运行传递后将保持这种不变性。

一对一传递

对于一对一映射的示例,如果我们要将操作 A 替换为另一个操作 B,我们可以运行给定的 fx.GraphModule,并且每次看到操作 A 时,返回操作 B。

考虑以下示例

class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
    """
    relu_ is the in-place version. Replace it with relu, which is the
    out-of-place version
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.relu_.default:
            return super().call_operator(op, args, kwargs, meta)
        return super().call_operator(Op(torch.ops.aten.relu.default), args, kwargs, meta)

# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module

super().call_operator(op, args, kwargs, meta) 调用创建一个 call_function FX 节点,并返回使用给定参数运行运算符的结果。

一对多传递

如果我们要执行一对多映射,例如将操作 A 替换为另外 2 个操作 B 和 C,那么我们将调用 super().call_operator 两次以创建 2 个 FX 节点,一个使用操作 B,另一个使用操作 C,并返回运行操作 C 的结果。

例如

class ReplaceAddWithMulSub(ExportPass):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.add.default:
            return super().call_operator(op, args, kwargs, meta)

        x, y = args

        mul_res = super().call_operator(
            torch.ops.aten.mul.default,
            args,
            {},
            meta
        )

        return super().call_operator(
            torch.ops.aten.sub.default,
            (mul_res, y),
            {},
            meta
        )

一对零传递

如果我们要删除一个操作,我们可以只返回传递到函数中的值

class RemoveDetachPass(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return args[0]

利用局部信息

利用局部节点信息的示例是,如果我们要将图形中的所有标量转换为张量,我们可以运行给定的 fx.GraphModule,并且对于包含标量的每个参数,我们将它转换为张量。它可能类似于

def args_map(op, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(self.op._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)

class ScalarToTensorPass(ExportPass):
    def call_operator(self, op, args, kwargs):
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(op, try_coerce, args, kwargs)
        return super().call_operator(op, args, kwargs)

级别 2

对于创建多对一映射,我们可以利用 FX 的 子图重写器。给定一个 pattern,它将创建一个与模式匹配的运算符子图,然后将每个匹配的子图替换为 replacement

注意

This is an inplace operation.

patternreplacement 输入必须是使用与您要匹配的 EXIR 图中相同的运算符(ATen 运算符)编写的可调用函数,以便子图重写器可以在图形中找到正确的模式。传递到模式/替换可调用的输入将被视为通配符。

考虑以下示例

from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)

子图重写器返回一个 ReplacedPatterns 列表

@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]

注意

The nodes created by the subgraph rewriter will not have the metadata that
is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).

级别 3

对于创建传递的第三种方法,我们可以利用最基本的 PassBase。要创建传递,我们可以为此创建子类并使用传递内容实现函数 call。此外,我们还可以实现函数 requiresensures,它们将在函数 call 之前和之后调用。请注意,这些函数也可以在 ExportPass 中覆盖。要对图形模块运行传递,我们可以将图形模块直接传递给类的实例。

考虑以下示例

class ReplaceAddPass(PassBase):

    def __init__(self, replace_op):
        self.replace_op = replace_op

    def call(self, graph_module):
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = self.replace_op

    # Optional to implement, will be called before call()
    def requires(self, graph_module) -> None:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                return
        raise ValueError("No torch.add ops!")

    # Optional to implement, will be called after call()
    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        pass

# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)

传递管理器

PassManager 是一个用于在给定图模块上运行多个传递的类。在初始化 PassManager 实例时,我们会传入一个要运行的传递列表并设置几个标志。要对图模块运行传递集合,我们可以将图模块直接传递给 PassManager 实例。

示例

from executorch.exir.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)

要添加在每次传递后运行的常见检查集,我们可以调用函数 set_checks(check: Callable),该函数将可调用函数作为输入。如果 run_checks_after_each_pass 标志已设置,则 check 将在每次传递在图模块上运行后被调用。

示例

pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)    # raises ValueError after replace_div_with_mul pass

分区器

我们可以使用几个基于 FX 图的常见分区器来对图进行分区。但是,这些并不一定能生成符合 IR 规范的图,因此在使用它们时要小心。

子图匹配器

要查找与特定模式匹配的图中的子图,我们可以利用 FX 的 SubgraphMatcher

类属性

  • pattern (Graph):目标匹配模式。图中的占位符节点将在匹配时被视为通配符。

  • match_output (bool):如果为 True,则模式图中的输出节点将被视为目标模式的一部分。如果为 False,则在匹配期间忽略输出节点。

  • match_placeholder (bool):如果为 True,则模式图中的占位符节点将被视为目标模式的一部分。如果为 False,则占位符节点将用作通配符。

  • remove_overlapping_matches (bool):如果为 True,则在出现重叠匹配的情况下,只返回第一个匹配项。

  • ignore_literals (bool):如果为 True,则不会检查字面量是否相等,而是将它们视为通配符。

考虑以下示例

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)

match 函数返回一个 InternalMatch 列表

@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

基于能力的分区器

要查找支持特定不变式的最大节点子图,我们可以利用 FX 的 CapabilityBasedPartitioner

类属性

  • graph_module (torch.fx.GraphModule):我们正在对其进行分区的图模块。

  • operator_support (OperatorSupportBase):用于确定图中节点是否在分区中支持的对象。

  • allows_single_node_partition (bool):如果为 True,则允许形成单个节点分区。

  • non_compute_ops (Optional[Sequence[str]]):一组被认为是“非计算”的操作(例如 torch.ops.aten.view_operator.getitem,这样分区器就不会创建只包含这些非计算操作的图

  • allowed_single_node_partition_ops (Optional[Sequence[str]]):一组允许在单个节点分区中使用的操作。

OperatorSupportBase 类由分区器用于确定图中的特定节点是否属于该分区。这是通过覆盖 is_node_supported 函数来完成的。您可以通过使用 chain(如果任何 OperatorSupportBase 返回 False,则返回 False)和 any_chain(如果任何 OperatorSupportBase 返回 True,则返回 True)来链接多个 OperatorSuppportBase

考虑以下示例

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()

如果您查看基于能力的分区器,您可能还会发现一个 fuse_partition 函数,该函数将返回一个修改后的图,其中分区作为子模块,并通过 call_module 节点在顶层图中调用这些子模块。但是,这与 IR 规范不兼容,因为我们不允许 call_module 节点。

组合

我们还提供一个组合的辅助函数:generate_pattern_op_partitions

参数

  • graph_module (fx.GraphModule):我们要分区的模块

  • patterns (List[torch.fx.Graph]):一个以 torch.fx.Graph 形式存在的模式列表。这些图可以通过 exir.capture(推荐)或符号跟踪(这可能不会导致准确的边缘方言图)获得的 GraphModule 中的 graph 字段获得,或者通过手动制作一个图模块获得。

  • op_support (OperatorSupportBase):一个 OperatorSupportBase,可以通过以下方式创建

    • 直接子类化并实现 is_node_supported()

    • 获取 create_op_support() 的结果

    • 获取 create_pattern_support() 的结果

    • 使用 chain()any_chain() 链接在一起的多个 OperatorSupportBase 类

返回值

  • 一个分区列表(最大可能的子图),其中包含由给定 OperatorSupportBase 对象和给定模式图的并集支持的节点。

源分区器

对于更复杂的使用案例,其中用户希望根据更高级别的模块(torch.nn.Lineartorch.nn.functional.Linear)进行分区,这些模块现在已分解为其操作符(aten.permuteaten.addmm),我们有以下 辅助函数

get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]

参数

  • graph:我们要分区的图

  • wanted_sources:从该源分解而来的节点的源列表。这可以是一个函数(例如 torch.nn.functional.linear)或一个叶模块类型(例如 torch.nn.Linear

返回值

  • 将源(例如 torch.nn.modules.linear.Linear)映射到与从该类型模块展开的节点列表相对应的 SourcePartitions 列表的字典。

@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
    # Module type
    module_type: Type
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)

示例

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
    return [addmm_default_2]
"""

module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
    ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源