快捷方式

后端方言

概述

后端方言边缘方言 的一种特殊变体,因为它包含后端特定的节点和元数据,这些节点和元数据是在后端特定的图转换之后添加的。后端方言是一个可选阶段,只有在我们想要将后端感知引入图时才需要。更具体地说,后端方言中的图可能包含仅对目标后端有意义的操作符或委托降低的模块(参见 委托文档)。一个用例是,如果我们想将操作符融合成单个操作符,例如,将连续的 addmm + relu 融合成单个操作符 addmm_relu,我们可以在此处进行操作。

本文档描述了如何引入后端特定的操作符。

自定义操作符和后端特定操作符之间的区别:自定义操作符出现在急切模式、ATen 方言和边缘方言中,而后端特定操作符仅由边缘方言之后发生的传递引入。

何时使用

这种方言允许引入不符合规范 ATen 操作符集中定义的模式,并且不在上述任何方言(ATen 方言和边缘方言)中出现的操作符。如果您的用例满足以下一个或多个标准,请考虑使用后端操作符

  • 您的后端提供一个库,该库优化了等效于子图的特定操作符。例如,linear_relu(等效于线性 + relu)可以在特定后端上更快地执行。

  • 需要在图模块降低到后端之后重新跟踪它。当我们重新跟踪时,后端操作符可以转换回原始子图(在 ATen 方言中),而普通自定义操作符不处理这种情况。

  • 您的后端特定操作符没有通用的 CPU 内核,而只有一个特定后端的内核。使用后端操作符可以通过使用原始子图作为默认内核并保持图模块可运行来解决此问题。

  • 或者,如果您担心这可能太过复杂,并且只想使用更轻量级的东西,并且只需要在编译阶段使用 Python 代码,则可以使用委托。

API

对于操作符/子图替换,常见的流程是

  1. 注册一个与子图具有相同输入和输出的操作符。此操作符不会具有特定于目标的实现(同样,在编译阶段也不需要),但它需要给出与子图相同的结果。

  2. 创建一个模式,允许编译器找到子图并用替换项替换它。

  3. 编写一个传递来用新操作符替换子图。

为了促进此过程,我们提供了一个 API 来帮助减少 ExecuTorch 用户执行这些步骤的工作量。

传递基础设施入口点

为了将边缘操作降低到后端操作,一个传递将执行模式匹配以识别图中感兴趣的边缘操作,然后用等效的后端操作替换它们。有两种 API 可以注册此类传递

  • transform()。ExportProgram 上的 API,允许用户提供自定义传递。请注意,这不受任何验证器的保护,因此无法保证程序的健全性。

  • ExecutorchBackendConfig.passes。如果在此处添加,则传递将成为从后端方言到 ExecutorchProgram 的降低过程的一部分。

示例:QuantFusion 就是这样一个传递。此传递采用“规范量化模式”,即“dequant - some_op - quant”,并将此模式融合成一个后端特定的单个操作符,即 quantized_decomposed::some_op。另一个更简单的示例是 这里,我们用 ExecuTorch 理解的操作符替换 sym_size 操作符

模式绑定装饰器

我们提供了一个装饰器 bind_pattern_to_op 来帮助用户轻松地将他们的后端操作符注册到 EXIR 中。此装饰器采用

  • 一个 torch.Library 对象,它指示此后端操作符属于哪个库或命名空间。

  • 一个名称或模式。如果我们已经在 torch.Library 对象中定义了后端操作符的模式,则只需要一个名称。否则,如果传递了模式字符串,我们可以注册模式。

此装饰器应添加到我们尝试匹配的模式(然后降低到此后端操作符)上边缘方言。这样,我们将此模式注册为此后端操作符的 CompositeImplicitAutograd 内核。

然后,操作符可以在传递中被访问/使用。 CompositeImplicitAutograd 内核确保

  1. 用户无需编写(CPU)可运行内核。

  2. 确保 ExportProgram 的可回溯性。一旦回溯,后端操作符将被分解为模式中使用的 ATen 操作。

示例

假设一个包含加法和 ReLU 操作符的简单程序

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = x + y
    return torch.ops.aten.relu.default(z)

降低到边缘方言后,它变成

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
    %aten_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.relu.default](args = (%aten_add_tensor,), kwargs = {})
    return (aten_relu_default,)

现在我想编写一个传递来将 addrelu 合并到 add_relu 中,第一步是编写一个模式

# In the pattern, we can use edge ops and ATen ops interchangably
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.ops.aten.add.Tensor(x, y)
    out = torch.ops.aten.relu.default(z)
    return out

然后我们需要从融合操作符命名空间创建一个操作符库,然后在我们的模式上使用装饰器

lib = Library("foo_namespace", "DEF")

@bind_pattern_to_op(lib, "add_relu(Tensor self, Tensor other) -> Tensor")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = torch.ops.aten.add.Tensor(x, y)
        out = torch.ops.aten.relu.default(z)
        return out

这样,我们将模式注册为 add_relu 的内核,它已准备好用于传递。一个简单的传递如下所示

class AddReluFusionPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        # decorator registers this pattern as a CompositeExplicitAutograd kernel, since there's no kernel registered before.
        @bind_pattern_to_op(lib, "add_relu")
        def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.ops.aten.add.Tensor(x, y)
            out = torch.ops.aten.relu.default(z)
            return out

        def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.foo_namespace.add_relu.default(x, y)

        subgraph_rewriter.replace_pattern(
            graph_module,
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
        )
        return PassResult(graph_module, True)

结果图如下所示

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %foo_namespace_add_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.foo_namespace.add_relu.default](args = (%arg0_1, %arg1_1), kwargs = {})
    return (foo_namespace_add_relu_default,)

操作符集

以下是当前使用 bind_pattern_to_op API 的后端操作符。

  • executorch_prims::add.int(SymInt a, SymInt b) -> SymInt

    • 模式:builtin.add

    • 后端:执行器

  • executorch_prims::mul.int(SymInt a, SymInt b) -> SymInt

    • 模式:builtin.mul

    • 后端:执行器

  • executorch_prims::sub.int(SymInt a, SymInt b) -> SymInt

    • 模式:builtin.sub

    • 后端:执行器

  • executorch_prims::floordiv.int(SymInt a, SymInt b) -> SymInt

    • 模式:builtin.floordiv

    • 后端:执行器

  • executorch_prims::gt.int(SymInt a, SymInt b) -> bool

    • 模式:builtin.gt

    • 后端:执行器

  • executorch_prims::lt.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.lt

    • 后端:执行器

  • executorch_prims::ge.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.ge

    • 后端:执行器

  • executorch_prims::le.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.le

    • 后端:执行器

  • executorch_prims::eq.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.eq

    • 后端:执行器

  • quantized_decomposed::embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc

    • pattern: source

    • backend: quantization

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源