使用 XNNPACK 后端构建和运行 ExecuTorch¶
本教程将帮助您熟悉如何利用 ExecuTorch XNNPACK 代理来使用 CPU 硬件加速您的 ML 模型。它将介绍如何将模型导出并序列化为二进制文件,以针对 XNNPACK 代理后端,并在支持的目标平台上运行模型。为了快速入门,请使用 ExecuTorch 存储库中的脚本,该脚本包含有关导出和生成二进制文件的说明,这些文件用于演示流程的几个示例模型。
在本教程中,您将学习如何导出 XNNPACK 降级模型并在目标平台上运行它
将模型降级到 XNNPACK¶
import torch
import torchvision.models as models
from torch.export import export, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
from executorch.exir.backend.backend_api import to_backend
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )
exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs)
edge: EdgeProgramManager = to_edge(exported_program)
edge = edge.to_backend(XnnpackPartitioner())
我们将使用从 TorchVision 库下载的 MobileNetV2 预训练模型来完成此示例。降级模型的流程从导出模型 to_edge
后开始。我们使用 XnnpackPartitioner
调用 to_backend
api。分区器识别适合 XNNPACK 后端代理使用的子图。之后,识别出的子图将使用 XNNPACK 代理 flatbuffer 模式序列化,每个子图将被替换为对 XNNPACK 代理的调用。
>>> print(edge.exported_program().graph_module)
GraphModule(
(lowered_module_0): LoweredBackendModule()
(lowered_module_1): LoweredBackendModule()
)
def forward(self, arg314_1):
lowered_module_0 = self.lowered_module_0
executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, arg314_1); lowered_module_0 = arg314_1 = None
getitem = executorch_call_delegate[0]; executorch_call_delegate = None
aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280]); getitem = None
aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None
lowered_module_1 = self.lowered_module_1
executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, aten_clone_default); lowered_module_1 = aten_clone_default = None
getitem_1 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None
return (getitem_1,)
我们在降级后打印图,以显示插入以调用 XNNPACK 代理的新节点。被委托给 XNNPACK 的子图是每个调用站点上的第一个参数。可以观察到,大多数 convolution-relu-add
块和 linear
块能够被委托给 XNNPACK。我们还可以看到无法降级到 XNNPACK 代理的操作符,例如 clone
和 view_copy
。
exec_prog = edge.to_executorch()
with open("xnnpack_mobilenetv2.pte", "wb") as file:
exec_prog.write_to_file(file)
将模型降低到 XNNPACK 程序后,我们可以将其准备用于 executorch 并将模型保存为 .pte
文件。 .pte
是一种二进制格式,用于存储序列化后的 ExecuTorch 图。
将量化模型降低到 XNNPACK¶
XNNPACK 代理还可以执行对称量化模型。要了解量化流程并学习如何量化模型,请参阅 自定义量化 说明。为了本教程的方便,我们将利用 quantize()
Python 辅助函数,该函数已方便地添加到 executorch/executorch/examples
文件夹中。
from torch._export import capture_pre_autograd_graph
from executorch.exir import EdgeCompileConfig
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )
mobilenet_v2 = capture_pre_autograd_graph(mobilenet_v2, sample_inputs) # 2-stage export for quantization path
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
def quantize(model, example_inputs):
"""This is the official recommended flow for quantization in pytorch 2.0 export"""
print(f"Original model: {model}")
quantizer = XNNPACKQuantizer()
# if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
m = prepare_pt2e(model, quantizer)
# calibration
m(*example_inputs)
m = convert_pt2e(m)
print(f"Quantized model: {m}")
# make sure we can export to flat buffer
return m
quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)
量化需要两阶段导出。首先,我们使用 capture_pre_autograd_graph
API 在将模型传递给 quantize
实用程序函数之前捕获模型。完成量化步骤后,我们现在可以利用 XNNPACK 代理来降低量化后的导出模型图。从这里开始,该过程与非量化模型降低到 XNNPACK 的过程相同。
# Continued from earlier...
edge = to_edge(export(quantized_mobilenetv2, sample_inputs), compile_config=EdgeCompileConfig(_check_ir_validity=False))
edge = edge.to_backend(XnnpackPartitioner())
exec_prog = edge.to_executorch()
with open("qs8_xnnpack_mobilenetv2.pte", "wb") as file:
exec_prog.write_to_file(file)
使用 aot_compiler.py
脚本降低¶
我们还提供了一个脚本,可以快速降低和导出一些示例模型。您可以运行该脚本以生成降低的 fp32 和量化模型。此脚本仅用于方便,并执行与前两节中列出的步骤相同的步骤。
python -m examples.xnnpack.aot_compiler --model_name="mv2" --quantize --delegate
请注意上面的示例中,
the
-—model_name
指定要使用的模型the
-—quantize
标志控制是否应量化模型the
-—delegate
标志控制我们是否尝试将图的一部分降低到 XNNPACK 代理。
生成的模型文件将被命名为 [model_name]_xnnpack_[qs8/fp32].pte
,具体取决于提供的参数。
使用 CMake 运行 XNNPACK 模型¶
导出 XNNPACK 代理模型后,我们现在可以使用 CMake 尝试使用示例输入运行它。我们可以构建和使用 xnn_executor_runner,它是 ExecuTorch 运行时和 XNNPACK 后端的示例包装器。我们首先通过配置 CMake 构建来开始,如下所示
# cd to the root of executorch repo
cd executorch
# Get a clean cmake-out directory
rm- -rf cmake-out
mkdir cmake-out
# Configure cmake
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_XNNPACK=ON \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_ENABLE_LOGGING=1 \
-DPYTHON_EXECUTABLE=python \
-Bcmake-out .
然后,您可以使用以下命令构建运行时组件
cmake --build cmake-out -j9 --target install --config Release
现在,您应该能够在 ./cmake-out/backends/xnnpack/xnn_executor_runner
处找到构建的可执行文件,您可以使用生成的模型运行可执行文件,如下所示
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack_fp32.pte
# or to run the quantized variant
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack_q8.pte
使用 Buck 运行 XNNPACK 模型¶
或者,您可以使用 buck2
在您的主机平台上运行包含 XNNPACK 代理指令的 .pte
文件。您可以按照这里的说明安装 buck2。现在,您可以使用示例中提供的预构建的 xnn_executor_runner
运行它。这将在一些示例输入上运行模型。
buck2 run examples/xnnpack:xnn_executor_runner -- --model_path ./mv2_xnnpack_fp32.pte
# or to run the quantized variant
buck2 run examples/xnnpack:xnn_executor_runner -- --model_path ./mv2_xnnpack_q8.pte