Android 食谱模型准备¶
此食谱演示了如何为 Android 应用程序准备 PyTorch MobileNet v2 图像分类模型,以及如何设置 Android 项目以使用移动就绪模型文件。
简介¶
在训练 PyTorch 模型或提供预训练模型后,它通常还没有准备好用于移动应用程序。它需要进行量化(请参阅量化食谱)、转换为 TorchScript 以便 Android 应用程序加载它,并针对移动应用程序进行优化。此外,Android 应用程序需要正确设置才能启用 PyTorch Mobile 库的使用,然后才能加载和使用模型进行推理。
步骤¶
1. 获取预训练和量化的 MobileNet v2 模型¶
要获取 MobileNet v2 量化模型,只需执行以下操作
import torchvision
model_quantized = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)
2. 为移动应用程序脚本化和优化模型¶
使用 script 或 trace 方法将量化模型转换为 TorchScript 格式
import torch
dummy_input = torch.rand(1, 3, 224, 224)
torchscript_model = torch.jit.trace(model_quantized, dummy_input)
或
torchscript_model = torch.jit.script(model_quantized)
警告
trace 方法仅对跟踪期间执行的代码路径进行脚本化,因此对于包含决策分支的模型,它将无法正常工作。有关更多详细信息,请参阅脚本化和优化移动食谱。
然后优化 TorchScript 格式的模型以用于移动设备并保存它
from torch.utils.mobile_optimizer import optimize_for_mobile
torchscript_model_optimized = optimize_for_mobile(torchscript_model)
torch.jit.save(torchscript_model_optimized, "mobilenetv2_quantized.pt")
通过以上两个步骤中的总共 7 或 8 行代码(取决于是否调用了 script 或 trace 方法来获取模型的 TorchScript 格式),我们得到了一个可以添加到移动应用程序中的模型。
3. 在 Android 上添加模型和 PyTorch 库¶
在您当前的或新的 Android Studio 项目中,打开 build.gradle 文件,并添加以下两行代码(如果您计划使用 TorchVision 模型,则仅需要第二行)
implementation 'org.pytorch:pytorch_android:1.6.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'
将模型文件 mobilenetv2_quantized.pt 拖放到项目的 assets 文件夹中。
就是这样!现在您可以使用 PyTorch 库和模型构建您的 Android 应用程序,并随时使用它。要实际编写代码来使用模型,请参考 PyTorch Mobile 的 Android 快速入门指南(含 HelloWorld 示例) 和 Android 黑客马拉松示例。