快捷方式

后端方言

概览

后端方言是 edge dialect 的一种特殊变体,因为它在后端特定的图转换后,包含后端特定的节点和元数据。后端方言是一个可选阶段,只有当我们想在图中引入后端感知(backend-awareness)时才需要。更具体地说,后端方言中的图可能包含仅对目标后端有意义的 operator 或委托的下沉模块(参见 delegate 文档)。一个用例是,如果我们想将 operator 融合到一个 operator 中,例如,将连续的 addmm + relu 融合到单个 operator addmm_relu 中,我们可以在此处进行。

本文档描述了如何引入后端特定的 operator。

自定义 operator 和后端特定 operator 的区别:自定义 operator 会出现在 eager 模式、ATen 方言和 edge 方言中,而后端特定 operator 仅由发生在 edge 方言之后的 pass 引入。

何时使用

此方言允许引入不符合 canonical ATen operator set 中定义的 schema 且不出现在上述任何方言(ATen 方言和 edge 方言)中的 operator。如果您的用例满足以下一个或多个标准,请考虑使用后端 operator

  • 您的后端提供了一个库,该库优化了等同于子图的某个 operator。例如,linear_relu(等同于 linear + relu),可以在特定后端上执行得更快。

  • 将图模块下沉到后端后,需要重新追踪(retrace)它。当我们重新追踪时,后端 operator 可以转换回原始子图(在 ATen 方言中),而普通的自定义 operator 不会处理这种情况。

  • 您的后端特定 operator 没有通用的 CPU kernel,只有针对特定后端的 kernel。使用后端 operator 可以通过使用原始子图作为默认 kernel 并保持图模块可运行来解决此问题。

  • 或者,如果您担心这可能过于复杂,并且只想要更轻量级且仅在编译器阶段需要 Python 代码的东西,则可以使用 delegate。

API

对于 operator/子图替换,通用流程是

  1. 注册一个与子图具有相同输入和输出的 operator。这个 operator 不需要具有目标特定的实现(在编译阶段也不需要),但它需要提供与子图相同的结果。

  2. 创建一个 pattern,使编译器能够找到该子图并将其替换为替代项。

  3. 编写一个 pass,用新的 operator 替换子图。

为了方便此过程,我们提供了一个 API 来帮助 ExecuTorch 用户减少执行这些步骤的工作量。

Pass 基础设施入口点

要将 edge operator 下沉到 backend operator,pass 将执行 pattern 匹配,识别图中的目标 edge operator,然后用等效的 backend operator 替换它们。有两种 API 可用于注册此类 pass

  • transform()。这是 ExportProgram 上的一个 API,允许用户提供自定义 pass。请注意,此 API 不受任何 validator 的保护,因此无法保证程序的正确性(soundness)。

  • ExecutorchBackendConfig.passes。如果在此处添加,该 pass 将成为从 backend 方言到 ExecutorchProgram 的下沉过程的一部分。

示例:一个这样的 pass 是 QuantFusion。这个 pass 接收一个“规范量化 pattern”,即“dequant - some_op - quant”,并将这个 pattern 融合到一个后端特定的 operator 中,例如 quantized_decomposed::some_op。另一个更简单的示例在此处:here,我们将 sym_size operator 替换为 ExecuTorch 可以理解的 operator。

Pattern 绑定 Decorator

我们提供了一个 decorator bind_pattern_to_op 来帮助用户轻松地将其后端 operator 注册到 EXIR 中。此 decorator 接受

  • 一个 torch.Library 对象,它指示此后端 operator 属于哪个 library 或 namespace。

  • 一个名称或 schema。如果已经在 torch.Library 对象中定义了后端 operator 的 schema,则只需要提供一个名称。否则,如果传入 schema 字符串,我们可以注册该 schema。

此 decorator 应添加到我们在 edge 方言中尝试匹配(然后下沉到此后端 operator)的 pattern 上。通过这种方式,我们将此 pattern 注册为该后端 operator 的 CompositeImplicitAutograd kernel。

然后可以从 pass 中访问/使用该 operator。CompositeImplicitAutograd kernel 确保

  1. 用户无需编写 (CPU) 可运行的 kernel。

  2. 确保 ExportProgram 的可重新追踪性(retrace-ability)。一旦重新追踪,后端 operator 将被分解回 pattern 中使用的 ATen operator。

示例

让我们假设一个包含 add 和 relu 两个 operator 的简单程序

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = x + y
    return torch.ops.aten.relu.default(z)

下沉到 edge 方言后,它变为

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
    %aten_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.relu.default](args = (%aten_add_tensor,), kwargs = {})
    return (aten_relu_default,)

现在我想编写一个 pass,将 add 和 relu 合并到 add_relu 中,第一步是编写一个 pattern

# In the pattern, we can use edge ops and ATen ops interchangably
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.ops.aten.add.Tensor(x, y)
    out = torch.ops.aten.relu.default(z)
    return out

然后我们需要从融合后的 operator namespace 创建一个 operator library,然后将 decorator 应用到我们的 pattern 上

lib = Library("foo_namespace", "DEF")

@bind_pattern_to_op(lib, "add_relu(Tensor self, Tensor other) -> Tensor")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = torch.ops.aten.add.Tensor(x, y)
        out = torch.ops.aten.relu.default(z)
        return out

通过这种方式,我们将 pattern 注册为 add_relu 的一个 kernel,并已准备好在 pass 中使用。一个简单的 pass 如下所示

class AddReluFusionPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        # decorator registers this pattern as a CompositeExplicitAutograd kernel, since there's no kernel registered before.
        @bind_pattern_to_op(lib, "add_relu")
        def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.ops.aten.add.Tensor(x, y)
            out = torch.ops.aten.relu.default(z)
            return out

        def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.foo_namespace.add_relu.default(x, y)

        subgraph_rewriter.replace_pattern(
            graph_module,
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
        )
        return PassResult(graph_module, True)

结果图如下所示

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %foo_namespace_add_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.foo_namespace.add_relu.default](args = (%arg0_1, %arg1_1), kwargs = {})
    return (foo_namespace_add_relu_default,)

Operator Set

以下是当前使用 bind_pattern_to_op API 的后端 operator。

  • executorch_prims::add.int(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.add

    • backend: executor

  • executorch_prims::mul.int(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.mul

    • backend: executor

  • executorch_prims::sub.int(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.sub

    • backend: executor

  • executorch_prims::floordiv.int(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.floordiv

    • backend: executor

  • executorch_prims::truediv.int(Scalar a, Scalar b) -> Scalar

    • pattern: builtin.div

    • backend: executor

  • executorch_prims::sym_float.Scalar(Scalar a) -> Scalar

    • pattern: builtin.float

    • backend: executor

  • executorch_prims::gt.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.gt

    • backend: executor

  • executorch_prims::lt.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.lt

    • backend: executor

  • executorch_prims::ge.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.ge

    • backend: executor

  • executorch_prims::le.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.le

    • backend: executor

  • executorch_prims::eq.int(SymInt a, SymInt b) -> bool

    • pattern: builtin.eq

    • backend: executor

  • executorch_prims::mod.Scalar(SymInt a, SymInt b) -> SymInt

    • pattern: builtin.divmod

    • backend: executor

  • executorch_prims::neg.Scalar(Scalar a) -> Scalar

    • pattern: operator.ne

    • backend: executor

  • quantized_decomposed::embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor

    • pattern: source

    • backend: quantization

  • quantized_decomposed::add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc

    • pattern: source

    • backend: quantization

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源