在 ATen IR 上编写图变换¶
传递¶
由于 ATen IR 位于 FX 图/图模块级别,因此为 FX 图编写的任何转换都可以轻松应用于 ATen IR。如果您熟悉编写 FX 图转换,那么这将是相同的。
编写转换的最直接方法是遍历给定的图并直接操作图中的节点。
例如,假设我们要用 torch.ops.aten.mul.Tensor() 调用替换 torch.ops.aten.add.Tensor() 调用
import torch
def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
我们还可以通过 FX 实用函数删除和追加新节点,这些函数可以在 Graph 文档中找到。例如,如果我们想在 add 调用之后插入一个 torch.ops.aten.relu.default()
import torch
def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            # Specifies the insertion point. Any nodes added to the graph within
            # this scope will be inserted after `node`
            with gm.graph.inserting_after(node):
                # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                # Replace all the places that use `node` to now use the `new_relu_node`
                node.replace_all_uses_with(new_relu_node)
一般来说,转换可以大致分为几个轴
轴 A:1. 创建一对多映射(例如,分解)2. 创建多对一映射(例如,融合)
轴 B:1. 进行正向迭代(例如,形状传播)2. 进行反向迭代(例如,死代码消除)
轴 C:1. 依赖于本地节点信息(例如,out-variant 转换)2. 依赖于全局图信息(例如,内存规划)
我们对这些用例频率的预测是:1. A.1、B.1、C.1 2. A.2 3. B.2、C.2
虽然我们可以通过直接操作图来进行所有图转换,但我们也提供了一些辅助工具,以便于使用 1 级和 2 级用例。
转换器¶
对于 1 级用例(创建一对多映射、进行正向迭代以及查看本地节点信息),我们可以利用 Transformer 类来执行每个节点并重新创建图,除了指定转换之外。
一对一传递¶
对于一对一映射的示例,如果我们想用另一个操作 B 替换操作 A,我们可以运行 GraphModule,并且每次看到操作 A 时,都返回操作 B。
一个例子是
class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)
transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
The super().call_function(target, args, kwargs, meta) 调用创建一个 call_function FX 节点,并返回使用给定参数运行运算符的结果。
一对多传递¶
如果我们想进行一对多映射,例如用另外两个操作 B 和 C 替换操作 A,那么我们将对 super().call_function 进行两次调用,以创建两个 FX 节点,一个使用操作 B,另一个使用操作 C,并返回运行操作 C 的结果。
例如
class ReplaceAddWithMulSub(torch.fx.Transformer):
    """
    Original:
        def f(x, y):
            return x + y
    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        x, y = args
        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})
transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
一对零传递¶
如果我们想删除一个操作,我们可以只返回传递给函数的值
class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs, meta)
        assert len(args) == 1
        return args[0]
transformed_graph_module = RemoveDetachPass(graph_module).transform()
利用本地信息¶
利用本地节点信息的示例是,如果我们想将图中的所有标量转换为张量,我们可以运行给定的 fx.GraphModule,并且对于包含标量的每个参数,我们将其转换为张量。它可能看起来像这样
def args_map(target, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()
    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)
    # Update each argument in the schema
    for i, schema in enumerate(target._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)
    return tuple(args), kwargs
class ScalarToTensorPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        breakpoint()
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )
        args, kwargs = args_map(target, try_coerce, args, kwargs)
        return super().call_function(target, args, kwargs)
transformed_graph_module = ScalarToTensorPass(graph_module).transform()
子图重写器¶
为了创建多对一映射,我们可以利用 FX 的 子图重写器。给定一个 pattern,它会创建一个与模式匹配的运算符子图,然后用 replacement 替换每个匹配的子图。
注意
This is an inplace operation.
pattern 和 replacement 输入必须是可调用函数或包含与图中使用的相同运算符(ATen 运算符)的 GraphModules,以便子图重写器可以在图中找到正确的模式。模式/替换可调用函数的输入在匹配时将被视为通配符。
示例
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x
    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
    traced_module, pattern, replacement
)
子图重写器返回一个 ReplacedPatterns 列表
@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]
注意
The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.
Pass Manager¶
`PassManager` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/pass_manager.py>`__ 是一个用于对给定图模块运行多个传递的类。在初始化 PassManager 实例时,我们传入要运行的传递列表并设置一些标志。要对图模块运行传递集合,我们可以将图模块直接传递给 PassManager 实例。
示例
from torch.fx.passes.infra.pass_manager import PassManager
pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
要添加在每次传递后运行的一组通用检查,我们可以调用函数 set_checks(check: Callable),该函数接受一个可调用函数作为输入。如果 run_checks_after_each_pass 标志已设置,则 check 将在对图模块运行每次传递后被调用。
示例
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")
pm.add_checks(check_div_target)
pm(graph_module)    # raises ValueError after replace_div_with_mul pass
Partitioner¶
我们可以使用一些常见的基于 FX 图的划分器来划分图。
子图匹配器¶
为了在与特定模式匹配的图中找到子图,我们可以利用 FX 的 `SubgraphMatcher` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/matcher_utils.py>`__。
类属性
- pattern (Graph): 目标匹配模式。图中的占位符节点在匹配时将被视为通配符。
- match_output (bool): 如果为 True,则模式图中的输出节点将被视为目标模式的一部分。如果为 False,则在匹配期间忽略输出节点。
- match_placeholder (bool): 如果为 True,则模式图中的占位符节点将被视为目标模式的一部分。如果为 False,则占位符节点将用作通配符。
- remove_overlapping_matches (bool): 如果为 True,在出现重叠匹配的情况下,只返回第一个匹配项。
- ignore_literals (bool): 如果为 True,将不检查文字是否相等,而是将它们视为通配符。
示例
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))
    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)
large_model_graph = torch.export(LargeModel(), inputs).graph
class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = torch.export(PatternModel(), inputs).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
The match function returns a list of InternalMatch
@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)
基于能力的划分器¶
为了找到支持特定不变式的最大节点子图,我们可以利用 FX 的 `CapabilityBasedPartitioner` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/infra/partitioner.py#L34>`__。
类属性
- graph_module (torch.fx.GraphModule): 我们要进行划分的图模块。
- operator_support (OperatorSupportBase): 用于确定图中节点是否在分区中受支持的对象。
- allows_single_node_partition (bool): 如果为 True,则允许形成单个节点分区。
- non_compute_ops (Optional[Sequence[str]]): 一组被认为是“非计算”的操作(例如- torch.ops.aten.view和- _operator.getitem,以便分区器不会创建仅包含这些非计算操作的图
- allowed_single_node_partition_ops (Optional[Sequence[str]]): 一组允许在单个节点分区中使用的操作。
__`OperatorSupportBase` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#LL28C1-L28C1>__ 类由分区器使用,以确定图中的特定节点是否属于分区。 这是通过覆盖 is_node_supported 函数来实现的。您可以通过使用 `chain` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L150>__(如果任何 OperatorSupportBase 返回 False 则返回 False)和 `any_chain` <https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/operator_support.py#L164>__(如果任何 OperatorSupportBase 返回 True 则返回 True)来链接多个 OperatorSupportBase。
示例
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]
capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)
# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)