(Beta) Convert MobileNetV2 to NNAPI
Introduction
This tutorial shows how to prepare a computer vision model to use Android's Neural Networks API (NNAPI). NNAPI provides access to powerful and efficient compute cores on many modern Android devices.
PyTorch's NNAPI support is currently in the "prototype" phase and only supports a limited range of operators, but we expect to solidify the integration and expand our operator coverage over time.
Model Preparation
First, we must prepare our model to execute with NNAPI. This step runs on your training server or laptop. The key conversion function to call is torch.backends._nnapi.prepare.convert_model_to_nnapi, but a few extra steps are required to ensure that the model is structured correctly. Most notably, quantizing the model is required in order to run it on certain accelerators.
You can copy/paste the entire Python script below and run it, or make your own modifications. By default, it saves the models to ~/mobilenetv2-nnapi/. Please create that directory first.
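If you prefer to create the directory from Python, a small sketch like the following is equivalent to mkdir -p (this is purely for convenience and not part of the tutorial's script):

from pathlib import Path

# Create ~/mobilenetv2-nnapi/ so the preparation script has somewhere to save the models.
Path.home().joinpath("mobilenetv2-nnapi").mkdir(parents=True, exist_ok=True)

The full model-preparation script follows.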
#!/usr/bin/env python
import sys
import os
import torch
import torch.utils.bundled_inputs
import torch.utils.mobile_optimizer
import torch.backends._nnapi.prepare
import torchvision.models.quantization.mobilenet
from pathlib import Path


# This script supports 3 modes of quantization:
# - "none": Fully floating-point model.
# - "core": Quantize the core of the model, but wrap it in a
#   quantizer/dequantizer pair, so the interface uses floating point.
# - "full": Quantize the model, and use quantized tensors
#   for input and output.
#
# "none" maintains maximum accuracy.
# "core" sacrifices some accuracy for performance,
# but maintains the same interface.
# "full" maximizes performance (with the same accuracy as "core"),
# but requires the application to use quantized tensors.
#
# There is a fourth option, not supported by this script,
# where we include the quant/dequant steps as NNAPI operators.
def make_mobilenetv2_nnapi(output_dir_path, quantize_mode):
    quantize_core, quantize_iface = {
        "none": (False, False),
        "core": (True, False),
        "full": (True, True),
    }[quantize_mode]

    model = torchvision.models.quantization.mobilenet.mobilenet_v2(pretrained=True, quantize=quantize_core)
    model.eval()

    # Fuse BatchNorm operators in the floating point model.
    # (Quantized models already have this done.)
    # Remove dropout for this inference-only use case.
    if not quantize_core:
        model.fuse_model()
    assert type(model.classifier[0]) == torch.nn.Dropout
    model.classifier[0] = torch.nn.Identity()

    input_float = torch.zeros(1, 3, 224, 224)
    input_tensor = input_float

    # If we're doing a quantized model, we need to trace only the quantized core.
    # So capture the quantizer and dequantizer, use them to prepare the input,
    # and replace them with identity modules so we can trace without them.
    if quantize_core:
        quantizer = model.quant
        dequantizer = model.dequant
        model.quant = torch.nn.Identity()
        model.dequant = torch.nn.Identity()
        input_tensor = quantizer(input_float)

    # Many NNAPI backends prefer NHWC tensors, so convert our input to channels_last,
    # and set the "nnapi_nhwc" attribute for the converter.
    input_tensor = input_tensor.contiguous(memory_format=torch.channels_last)
    input_tensor.nnapi_nhwc = True

    # Trace the model.  NNAPI conversion only works with TorchScript models,
    # and traced models are more likely to convert successfully than scripted.
    with torch.no_grad():
        traced = torch.jit.trace(model, input_tensor)
    nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced, input_tensor)

    # If we're not using a quantized interface, wrap a quant/dequant around the core.
    if quantize_core and not quantize_iface:
        nnapi_model = torch.nn.Sequential(quantizer, nnapi_model, dequantizer)
        model.quant = quantizer
        model.dequant = dequantizer

    # Switch back to float input for benchmarking.
    input_tensor = input_float.contiguous(memory_format=torch.channels_last)

    # Optimize the CPU model to make CPU-vs-NNAPI benchmarks fair.
    model = torch.utils.mobile_optimizer.optimize_for_mobile(torch.jit.script(model))

    # Bundle sample inputs with the models for easier benchmarking.
    # This step is optional.
    class BundleWrapper(torch.nn.Module):
        def __init__(self, mod):
            super().__init__()
            self.mod = mod
        def forward(self, arg):
            return self.mod(arg)
    nnapi_model = torch.jit.script(BundleWrapper(nnapi_model))
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        nnapi_model, [(torch.utils.bundled_inputs.bundle_large_tensor(input_tensor),)])

    # Save both models.
    model._save_for_lite_interpreter(str(output_dir_path / ("mobilenetv2-quant_{}-cpu.pt".format(quantize_mode))))
    nnapi_model._save_for_lite_interpreter(str(output_dir_path / ("mobilenetv2-quant_{}-nnapi.pt".format(quantize_mode))))


if __name__ == "__main__":
    for quantize_mode in ["none", "core", "full"]:
        make_mobilenetv2_nnapi(Path(os.environ["HOME"]) / "mobilenetv2-nnapi", quantize_mode)
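Running the script produces six .pt files under ~/mobilenetv2-nnapi/: a CPU and an NNAPI variant for each of the three quantization modes. As a quick sanity check that the exports load and run, a minimal sketch like the following (using the float CPU export, which needs no NNAPI library, and assuming a standard desktop Linux build of PyTorch) can be run on the same machine:

import torch
from pathlib import Path

# Quick sanity check: load the float CPU export and run its bundled sample input.
model = torch.jit.load(str(Path.home() / "mobilenetv2-nnapi/mobilenetv2-quant_none-cpu.pt"))
out = model(*model.get_all_bundled_inputs()[0])
print(out.shape)  # expected: torch.Size([1, 1000]) -- ImageNet class logits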
Running Benchmarks
Now that the models are ready, we can benchmark them on our Android devices. See our performance recipe for details. The best-performing models are likely to be the "fully quantized" ones: mobilenetv2-quant_full-cpu.pt and mobilenetv2-quant_full-nnapi.pt.
Because these models have bundled inputs, we can run the benchmark as follows:
./speed_benchmark_torch --pthreadpool_size=1 --model=mobilenetv2-quant_full-nnapi.pt --use_bundled_input=0 --warmup=5 --iter=200
Increasing the thread pool size can reduce latency, at the cost of increased CPU usage. Omitting that argument will use one thread per big core. The CPU models can get better performance (at the cost of memory usage) by passing --use_caching_allocator=true.
Running the Model on a Host
We can now run the model on a Linux machine using the reference implementation of NNAPI. You will need to build the NNAPI library from the Android source tree; make sure you have at least 200 GB of disk space available.
Follow these instructions to install repo, then run:
mkdir ~/android-nnapi && cd ~/android-nnapi
repo init -u https://android.googlesource.com/platform/manifest -b master
repo sync --network-only -j 16
repo sync -l
. build/envsetup.sh
lunch aosp_x86_64-eng
mm -j16 out/host/linux-x86/lib64/libneuralnetworks.so
With the host build of libneuralnetworks.so, you can run PyTorch NNAPI models on your Linux machine:
#!/usr/bin/env python
import ctypes
import torch
from pathlib import Path
ctypes.cdll.LoadLibrary(Path.home() / "android-nnapi/out/host/linux-x86/lib64/libneuralnetworks.so")
model = torch.jit.load(Path.home() / "mobilenetv2-nnapi/mobilenetv2-quant_full-nnapi.pt")
print(model(*model.get_all_bundled_inputs()[0]))
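If you also want a rough latency number from the host (it will not match on-device NNAPI performance, but it is handy as a quick reference), a minimal sketch like the following mirrors the --warmup=5 --iter=200 settings used in the benchmark command above; the library and model paths are the same ones used in the script just shown:

import ctypes
import time
import torch
from pathlib import Path

# Load the host build of the NNAPI reference library, then the converted model.
ctypes.cdll.LoadLibrary(Path.home() / "android-nnapi/out/host/linux-x86/lib64/libneuralnetworks.so")
model = torch.jit.load(Path.home() / "mobilenetv2-nnapi/mobilenetv2-quant_full-nnapi.pt")
inputs = model.get_all_bundled_inputs()[0]

with torch.no_grad():
    for _ in range(5):      # warmup, mirroring --warmup=5
        model(*inputs)
    start = time.perf_counter()
    for _ in range(200):    # timed iterations, mirroring --iter=200
        model(*inputs)
elapsed_ms = (time.perf_counter() - start) / 200 * 1000
print("{:.2f} ms / iteration".format(elapsed_ms))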
Integration
The converted models are ordinary TorchScript models. You can use them in your app just like any other PyTorch model. See https://pytorch.ac.cn/mobile/android/ for an introduction to using PyTorch on Android.
Learn More
Learn more about optimization in our Mobile Performance Recipe.
MobileNetV2 from torchvision
Information about NNAPI