快捷键

导出 IR 规范

导出 IR 是 torch.export 结果的中间表示 (IR)。要详细了解导出 IR,请阅读本 文档

导出的 IR 是一种规范,它包含以下部分

  1. 计算图模型的定义。

  2. 图中允许的运算符集。

**方言** 是一个用下面定义的运算符组成的导出 IR 图,但具有针对特定目的的附加属性(例如运算符集或元数据的限制)。

当前存在的 EXIR 方言是

这些方言代表了捕获的程序从程序捕获到转换为可执行格式所经历的阶段。例如,ExecuTorch 编译过程从 Python 程序捕获到 ATen 方言开始,然后将 ATen 方言转换为 Edge 方言,Edge 转换为后端,最后转换为用于执行的二进制格式。

ATen 方言

ATen 方言将用作 ExecuTorch 编译管道的入口点。这是急切模式 PyTorch 程序第一次成为导出 IR 图。在此阶段,执行函数化,删除所有张量别名和突变,并允许进行更灵活的图转换。此外,所有张量都被转换为连续格式。

该方言的目标是以尽可能忠实的方式捕获用户程序(同时保持有效的导出 IR)。在急切模式下用户调用的已注册自定义运算符将按原样保留在 ATen 方言中。但是,我们应该避免通过传递在图中添加自定义运算符。

目前,ATen 方言的功能是进一步降低到 Edge 方言。但是,将来我们可以将其视为其他导出用例的通用集成点。

ATen 方言属性

ATen 方言图是一个有效的导出 IR 图,具有以下附加属性

  1. 所有 call_function 节点中的运算符要么是 ATen 运算符(在 torch.ops.aten 命名空间中),要么是高阶运算符(如控制流运算符),要么是已注册的自定义运算符。已注册的自定义运算符是已注册到当前 PyTorch 急切模式运行时的运算符,通常使用 TORCH_LIBRARY 调用(意味着模式)。有关如何注册自定义运算符的详细信息,请参阅 此处

  2. 每个运算符还必须具有一个元内核。元内核是一个函数,它在给定输入张量的形状时,可以返回输出张量的形状。有关如何编写元内核的详细信息,请参阅 此处

  3. 输入值类型必须是“Pytree 可用的”。因此,输出类型也是 Pytree 可用的,因为所有运算符的输出都是 Pytree 可用的。

  4. ATen 方言的运算符可以选择使用动态数据类型、隐式类型提升和张量的隐式广播。

  5. 所有张量的内存格式均为 torch.contiguous_format

ATen 运算符定义

运算符集定义可在 此处 找到。

Edge 方言

该方言旨在引入对边缘设备有用的特殊化,但并非一定适用于通用(服务器)导出。但是,我们仍然保留进一步专门化到每个不同的硬件。换句话说,我们不想引入任何新的依赖于硬件的概念或数据;除了用户原始 Python 程序中已经存在的那些。

Edge 方言属性

Edge 方言图是一个有效的导出 IR 图,具有以下附加属性

  1. OpCall 节点中的所有运算符要么来自预定义的运算符集,称为 **“Edge 运算符”**,要么是已注册的自定义运算符。Edge 运算符是具有数据类型专门化的 ATen 运算符。这允许用户注册仅对某些数据类型有效的内核以减小二进制文件大小。

  2. 图的输入和输出,以及每个节点的输入和输出,都不能是标量。也就是说,所有标量类型(如浮点数、整数)都转换为张量。

使用 Edge 方言

Edge 方言在内存中使用 exir.EdgeProgramManager Python 类表示。这包含一个或多个 torch.export.ExportedProgram,它们包含方法的图表示。

import torch
from executorch import exir

class MyModule(torch.nn.Module):
    ...

a = MyModule()
tracing_inputs = (torch.rand(2, 2),)
aten_dialect_program = torch.export.export(a, tracing_inputs)
edge_dialect_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
print(edge_dialect_program.exported_program)

此时,用户定义的图转换可以通过 edge_dialect_program.transform(pass) 来运行。顺序很重要。注意:如果自定义的 pass 涉及 node.target,请注意,此阶段的所有 node.target 都是 “Edge ops”(更多细节见下文),而不是像 ATen 方言中的 torch ops。有关编写 pass 的教程,请点击 此处。在执行完所有这些 pass 后,to_edge() 将确保图仍然有效。

边缘算子

如前所述,边缘算子是具有类型专门化的 ATen 核心算子。这意味着边缘算子的实例包含一组 dtype 约束,这些约束描述了 ExecuTorch 运行时及其 ATen 内核支持的所有张量 dtype。这些 dtype 约束以在 edge.yaml 中定义的 DSL 表示。以下是 dtype 约束的示例

- func: sigmoid
  namespace: edge
  inherits: aten::sigmoid
  type_alias:
    T0: [Bool, Byte, Char, Int, Long, Short]
    T1: [Double, Float]
    T2: [Float]
  type_constraint:
  - self: T0
    __ret_0: T2
  - self: T1
    __ret_0: T1

这表示如果 self 张量是 Bool, Byte, Char, Int, Long, Short 类型之一,则返回的张量将为 Float。如果 selfDouble, Float 之一,则返回的张量将具有相同的 dtype。

在收集并记录在 edge.yaml 中后,EXIR 会使用该文件并将约束加载到 EXIR 边缘算子中。这使得开发人员可以方便地了解边缘运算符架构中任何参数支持的 dtype。例如,我们可以执行以下操作

from executorch.exir.dialects._ops import ops as exir_ops # import dialects ops
sigmoid = exir_ops.edge.aten.sigmoid.default
print(sigmoid._schema)
# aten::sigmoid(Tensor self) -> Tensor
self_arg = sigmoid._schema.arguments[0]
_return = sigmoid._schema.returns[0]

print(self_arg.allowed_types)
# {torch.float32, torch.int8, torch.float64, torch.int16, torch.int32, torch.int64, torch.uint8, torch.bool}

print(_return.allowed_types)
# {torch.float32, torch.float64}

这些约束对于想要为该算子编写自定义内核的人员很有帮助。此外,EXIR 内部还提供了一个验证器,用于在自定义转换后检查图是否仍然符合这些 dtype 约束。

运算符集(WIP)

请查看 edge.yaml 以获取指定了 dtype 约束的运算符的完整列表。我们正在逐步扩展此运算符集,目标是为所有核心 ATen 运算符提供 dtype 约束。

后端方言

请参见此 文档

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源