后端方言¶
概述¶
后端方言 是 边缘方言 的一个特殊变体,因为它在后端特定的图变换之后,包含后端特定的节点和元数据。后端方言是一个可选阶段,只有当我们想在图中引入后端感知时才需要。更具体地说,后端方言中的图可能包含运算符或委托降低的模块(参见 委托文档),这些运算符或模块仅对目标后端有意义。一个用例是,如果我们想将运算符融合到一个运算符中,例如,将连续的 addmm + relu 融合到一个 addmm_relu 运算符中,我们可以在这里这样做。
本文档描述了如何引入后端特定的运算符。
自定义运算符和后端特定运算符之间的区别:虽然自定义运算符出现在 eager 模式、ATen 方言和边缘方言中,但后端特定运算符仅在边缘方言之后发生的 Pass 中引入。
何时使用¶
此方言允许引入不符合规范 ATen 运算符集中定义的模式的运算符,并且这些运算符不会出现在上述任何方言(ATen 方言和边缘方言)中。如果您的用例满足以下一个或多个条件,请考虑使用后端运算符
您的后端提供了一个库,该库优化了某个等同于子图的运算符。例如,
linear_relu
(等效于 linear + relu),可以在特定后端上更快地执行。需要在已降低到后端的图模块之后重新跟踪该模块。当我们重新跟踪时,后端运算符可以转换回原始子图(在 ATen 方言中),而普通的自定义运算符则无法处理这种情况。
您的后端特定运算符没有通用的 CPU 内核,但只有一个用于特定后端的内核。使用后端运算符可以通过使用原始子图作为默认内核并保持图模块可运行来解决此问题。
或者,如果您担心这可能矫枉过正,并且只想使用更轻量级的东西,并且只需要编译器阶段的 Python 代码,则可以使用委托。
API¶
对于运算符/子图替换,常见的流程是
注册一个与子图具有相同输入和输出的运算符。此运算符将不具有特定于目标的实现(在编译阶段也不需要),但它需要给出与子图相同的结果。
创建一个模式,允许编译器找到子图并将其替换为替换项。
编写一个 Pass,将子图替换为新运算符。
为了简化此过程,我们提供了一个 API,以帮助 ExecuTorch 用户减少执行这些步骤的工作量。
Pass 基础设施入口点¶
为了将边缘运算符降低为后端运算符,Pass 将执行模式匹配以识别图中感兴趣的边缘运算符,然后将它们替换为等效的后端运算符。有两个 API 可以注册此类 Pass
transform()
。ExportProgram 上的一个 API,允许用户提供自定义 Pass。请注意,这不受任何验证器的保护,因此程序的健全性无法保证。ExecutorchBackendConfig.passes。如果在此处添加,则 Pass 将成为从后端方言到 ExecutorchProgram 的降低过程的一部分。
示例:一个这样的 Pass 是 QuantFusion。此 Pass 采用“规范量化模式”,即“dequant - some_op - quant”,并将此模式融合到单个后端特定的运算符中,即 quantized_decomposed::some_op
。另一个更简单的示例是 此处,我们将 sym_size
运算符替换为 ExecuTorch 可以理解的运算符
模式绑定装饰器¶
我们提供了一个装饰器 bind_pattern_to_op
,以帮助用户轻松地将其后端运算符注册到 EXIR 中。此装饰器接受
一个
torch.Library
对象,它指示此后端运算符所属的库或命名空间。名称或模式。如果我们在
torch.Library
对象中已经定义了后端运算符的模式,则只需要一个名称。否则,如果传入模式字符串,我们可以注册模式。
此装饰器应添加到我们尝试在边缘方言上匹配(然后降低为此后端运算符)的模式。这样,我们将此模式注册为此后端运算符的 CompositeImplicitAutograd
内核。
然后,可以从 Pass 访问/使用该运算符。CompositeImplicitAutograd
内核确保
用户无需编写(CPU)可运行内核。
确保
ExportProgram
的可重跟踪性。一旦重新跟踪,后端运算符将被分解为模式中使用的 ATen 运算符。
示例¶
让我们假设一个简单的程序,其中包含 add 和 relu 运算符
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
z = x + y
return torch.ops.aten.relu.default(z)
降低到边缘方言后,它变为
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
%aten_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.relu.default](args = (%aten_add_tensor,), kwargs = {})
return (aten_relu_default,)
现在我想编写一个 Pass,将 add
和 relu
合并为 add_relu
,第一步是编写一个模式
# In the pattern, we can use edge ops and ATen ops interchangably
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
z = torch.ops.aten.add.Tensor(x, y)
out = torch.ops.aten.relu.default(z)
return out
然后我们需要从融合运算符命名空间创建一个运算符库,然后在我们的模式上使用装饰器
lib = Library("foo_namespace", "DEF")
@bind_pattern_to_op(lib, "add_relu(Tensor self, Tensor other) -> Tensor")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
z = torch.ops.aten.add.Tensor(x, y)
out = torch.ops.aten.relu.default(z)
return out
这样,我们将该模式注册为 add_relu
的内核,并且可以随时在 Pass 中使用。一个简单的 Pass 如下所示
class AddReluFusionPass(ExportPass):
def call(self, graph_module: GraphModule) -> PassResult:
# decorator registers this pattern as a CompositeExplicitAutograd kernel, since there's no kernel registered before.
@bind_pattern_to_op(lib, "add_relu")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
z = torch.ops.aten.add.Tensor(x, y)
out = torch.ops.aten.relu.default(z)
return out
def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.ops.foo_namespace.add_relu.default(x, y)
subgraph_rewriter.replace_pattern(
graph_module,
_trace_and_lower_to_edge_ops(pattern),
_trace_and_lower_to_edge_ops(replacement),
)
return PassResult(graph_module, True)
结果图如下所示
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%foo_namespace_add_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.foo_namespace.add_relu.default](args = (%arg0_1, %arg1_1), kwargs = {})
return (foo_namespace_add_relu_default,)
运算符集¶
以下是当前使用 bind_pattern_to_op
API 的后端运算符。
executorch_prims::add.int(SymInt a, SymInt b) -> SymInt
模式:builtin.add
后端:executor
executorch_prims::mul.int(SymInt a, SymInt b) -> SymInt
模式:builtin.mul
后端:executor
executorch_prims::sub.int(SymInt a, SymInt b) -> SymInt
模式:builtin.sub
后端:executor
executorch_prims::floordiv.int(SymInt a, SymInt b) -> SymInt
模式:builtin.floordiv
后端:executor
executorch_prims::truediv.int(Scalar a, Scalar b) -> Scalar
模式:builtin.div
后端:executor
executorch_prims::sym_float.Scalar(Scalar a) -> Scalar
模式:builtin.float
后端:executor
executorch_prims::gt.int(SymInt a, SymInt b) -> bool
模式:builtin.gt
后端:executor
executorch_prims::lt.int(SymInt a, SymInt b) -> bool
模式:builtin.lt
后端:executor
executorch_prims::ge.int(SymInt a, SymInt b) -> bool
模式:builtin.ge
后端:executor
executorch_prims::le.int(SymInt a, SymInt b) -> bool
模式:builtin.le
后端:executor
executorch_prims::eq.int(SymInt a, SymInt b) -> bool
模式:builtin.eq
后端:executor
executorch_prims::mod.Scalar(SymInt a, SymInt b) -> SymInt
模式:builtin.divmod
后端:executor
executorch_prims::neg.Scalar(Scalar a) -> Scalar
模式:operator.ne
后端:executor
quantized_decomposed::embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor
模式:source
后端:量化
quantized_decomposed::add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc
模式:source
后端:量化
quantized_decomposed::add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor
模式:source
后端:量化
quantized_decomposed::add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc
模式:source
后端:量化