使用 XNNPACK 后端构建和运行 ExecuTorch¶
本教程将帮助您熟悉如何利用 ExecuTorch XNNPACK Delegate 来使用 CPU 硬件加速您的 ML 模型。本教程将介绍如何将模型导出并序列化为二进制文件,以 XNNPACK Delegate 后端为目标,并在受支持的目标平台运行模型。为了快速入门,请使用 ExecuTorch 代码仓库中提供的脚本,其中包含用于导出和生成一些示例模型二进制文件的说明,这些示例模型展示了整个流程。
在本教程中,您将学习如何导出经 XNNPACK 降级处理后的模型,并在目标平台运行它。
将模型降级处理到 XNNPACK¶
import torch
import torchvision.models as models
from torch.export import export, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge_transform_and_lower
from executorch.exir.backend.backend_api import to_backend
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )
exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs)
edge: EdgeProgramManager = to_edge_transform_and_lower(
exported_program,
partitioner=[XnnpackPartitioner()],
)
我们将使用从 TorchVision 库下载的 MobileNetV2 预训练模型来讲解这个示例。模型的降级处理流程在将模型 to_edge
导出后开始。我们调用 to_backend
API,并传入 XnnpackPartitioner
。Partitioner 会识别适合 XNNPACK 后端 Delegate 使用的子图。然后,识别出的子图将使用 XNNPACK Delegate flatbuffer 模式进行序列化,并且每个子图都将被替换为对 XNNPACK Delegate 的调用。
>>> print(edge.exported_program().graph_module)
GraphModule(
(lowered_module_0): LoweredBackendModule()
(lowered_module_1): LoweredBackendModule()
)
def forward(self, b_features_0_1_num_batches_tracked, ..., x):
lowered_module_0 = self.lowered_module_0
lowered_module_1 = self.lowered_module_1
executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, x); lowered_module_1 = x = None
getitem_53 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None
aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem_53, [1, 1280]); getitem_53 = None
aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None
executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, aten_clone_default); lowered_module_0 = aten_clone_default = None
getitem_52 = executorch_call_delegate[0]; executorch_call_delegate = None
return (getitem_52,)
我们在上面降级处理后打印了图,以显示为调用 XNNPACK Delegate 而插入的新节点。被委托给 XNNPACK 的子图是每个调用位置的第一个参数。可以观察到,大多数 convolution-relu-add
块和 linear
块能够被委托给 XNNPACK。我们还可以看到未能降级到 XNNPACK Delegate 的算子,例如 clone
和 view_copy
。
exec_prog = edge.to_executorch()
with open("xnnpack_mobilenetv2.pte", "wb") as file:
exec_prog.write_to_file(file)
将模型降级处理到 XNNPACK Program 后,我们可以为其准备 ExecuTorch,并将模型保存为 .pte
文件。.pte
是一种二进制格式,用于存储序列化的 ExecuTorch 图。
将量化模型降级处理到 XNNPACK¶
XNNPACK delegate 也可以执行对称量化模型。要理解量化流程并学习如何量化模型,请参阅自定义量化 (Custom Quantization) 说明。在本教程中,我们将利用方便地添加到 executorch/executorch/examples
文件夹中的 quantize()
python 辅助函数。
from torch.export import export_for_training
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )
mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
def quantize(model, example_inputs):
"""This is the official recommended flow for quantization in pytorch 2.0 export"""
print(f"Original model: {model}")
quantizer = XNNPACKQuantizer()
# if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
m = prepare_pt2e(model, quantizer)
# calibration
m(*example_inputs)
m = convert_pt2e(m)
print(f"Quantized model: {m}")
# make sure we can export to flat buffer
return m
quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs)
量化需要分两阶段导出。首先,我们使用 export_for_training
API 在将模型交给 quantize
实用函数之前捕获模型。执行量化步骤后,我们现在可以利用 XNNPACK delegate 对量化后的导出模型图进行降级处理。从这里开始,流程与非量化模型降级处理到 XNNPACK 的过程相同。
# Continued from earlier...
edge = to_edge_transform_and_lower(
export(quantized_mobilenetv2, sample_inputs),
compile_config=EdgeCompileConfig(_check_ir_validity=False),
partitioner=[XnnpackPartitioner()]
)
exec_prog = edge.to_executorch()
with open("qs8_xnnpack_mobilenetv2.pte", "wb") as file:
exec_prog.write_to_file(file)
使用 aot_compiler.py
脚本进行降级处理¶
我们还提供了一个脚本,可以快速对几个示例模型进行降级处理和导出。您可以运行该脚本来生成降级处理后的 fp32 和量化模型。此脚本仅为方便起见而提供,执行的步骤与前两节中列出的步骤完全相同。
python -m examples.xnnpack.aot_compiler --model_name="mv2" --quantize --delegate
请注意在上面的示例中,
-—model_name
指定要使用的模型-—quantize
标志控制模型是否应该被量化-—delegate
标志控制我们是否尝试将图的部分内容降级处理到 XNNPACK delegate。
生成的模型文件将根据提供的参数命名为 [model_name]_xnnpack_[qs8/fp32].pte
。
使用 CMake 运行 XNNPACK 模型¶
导出 XNNPACK Delegated 模型后,我们现在可以使用 CMake 尝试使用示例输入运行它。我们可以构建和使用 xnn_executor_runner,它是 ExecuTorch 运行时和 XNNPACK 后端的一个示例包装器。我们首先通过如下方式配置 CMake 构建
# cd to the root of executorch repo
cd executorch
# Get a clean cmake-out directory
./install_executorch.sh --clean
mkdir cmake-out
# Configure cmake
cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_BUILD_XNNPACK=ON \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DPYTHON_EXECUTABLE=python \
-Bcmake-out .
然后您可以使用以下命令构建运行时组件
cmake --build cmake-out -j9 --target install --config Release
现在您应该能够在 ./cmake-out/backends/xnnpack/xnn_executor_runner
找到构建好的可执行文件,您可以通过如下方式运行该可执行文件并传入您生成的模型
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack_fp32.pte
# or to run the quantized variant
./cmake-out/backends/xnnpack/xnn_executor_runner --model_path=./mv2_xnnpack_q8.pte
使用 XNNPACK 后端进行构建和链接¶
您可以构建 XNNPACK 后端 CMake target,并将其与您的应用程序二进制文件链接,例如 Android 或 iOS 应用程序。有关更多信息,您可以接下来查看此资源。
性能分析¶
要在 xnn_executor_runner
中启用性能分析,请将标志 -DEXECUTORCH_ENABLE_EVENT_TRACER=ON
和 -DEXECUTORCH_BUILD_DEVTOOLS=ON
传递给构建命令(添加 -DENABLE_XNNPACK_PROFILING=ON
以获取更多详细信息)。这将在使用推理时启用 ETDump 生成,并启用用于性能分析的命令行标志(详情请参阅 xnn_executor_runner --help
)。