• 教程 >
  • Android 上的图像分割 DeepLabV3
快捷方式

Android 上的图像分割 DeepLabV3

作者: Jeff Tang

审阅者: Jeremiah Chung

警告

PyTorch Mobile 现已不再积极维护。请查看 ExecuTorch,PyTorch 的全新设备端推理库。您还可以查看我们的 端到端工作流程 并查看 DeepLabV3 的源代码

简介

语义图像分割是一种计算机视觉任务,它使用语义标签来标记输入图像的特定区域。PyTorch 语义图像分割 DeepLabV3 模型 可用于使用 20 个语义类别(例如,自行车、公交车、汽车、狗和人)来标记图像区域。图像分割模型在自动驾驶和场景理解等应用中非常有用。

在本教程中,我们将提供一个分步指南,介绍如何在 Android 上准备和运行 PyTorch DeepLabV3 模型,从开始使用您可能想要在 Android 上使用的模型到最终拥有一个使用该模型的完整 Android 应用程序。我们还将介绍一些实用且通用的技巧,说明如何检查您接下来喜欢的预训练 PyTorch 模型是否可以在 Android 上运行,以及如何避免陷阱。

注意

在学习本教程之前,您应该查看 适用于 Android 的 PyTorch Mobile,并快速尝试 PyTorch Android Hello World 示例应用程序。本教程将超越通常部署在移动设备上的第一种模型——图像分类模型。本教程的完整代码可在 此处 获取。

学习目标

在本教程中,您将学习如何

  1. 将DeepLabV3模型转换为适用于Android部署的格式。

  2. 使用Python获取示例输入图像的模型输出,并将其与Android应用程序的输出进行比较。

  3. 构建新的Android应用程序或重用Android示例应用程序来加载转换后的模型。

  4. 将输入准备为模型期望的格式,并处理模型输出。

  5. 完成UI,重构,构建并运行应用程序,以查看图像分割的实际效果。

先决条件

  • PyTorch 1.6或1.7

  • torchvision 0.7或0.8

  • Android Studio 3.5.1或更高版本,并安装NDK

步骤

1. 将DeepLabV3模型转换为适用于Android部署的格式

在Android上部署模型的第一步是将模型转换为TorchScript格式。

注意

目前并非所有PyTorch模型都能够转换为TorchScript,因为模型定义可能使用了TorchScript中不存在的语言特性,而TorchScript是Python的一个子集。有关更多详细信息,请参阅脚本和优化指南

只需运行以下脚本即可生成脚本化的模型deeplabv3_scripted.pt

import torch

# use deeplabv3_resnet50 instead of resnet101 to reduce the model size
model = torch.hub.load('pytorch/vision:v0.7.0', 'deeplabv3_resnet50', pretrained=True)
model.eval()

scriptedm = torch.jit.script(model)
torch.jit.save(scriptedm, "deeplabv3_scripted.pt")

生成的deeplabv3_scripted.pt模型文件大小约为168MB。理想情况下,模型还应进行量化以大幅减小尺寸并在Android应用程序上实现更快的推理。为了对量化有一个大致的了解,请参阅量化指南及其中的资源链接。我们将在未来的教程或指南中详细介绍如何将称为训练后静态量化的量化工作流程正确应用于DeepLabV3模型。

2. 使用Python获取模型的示例输入和输出

现在我们有了脚本化的PyTorch模型,让我们使用一些示例输入进行测试,以确保模型在Android上能够正常工作。首先,让我们编写一个Python脚本,使用该模型进行推理并检查输入和输出。对于DeepLabV3模型的这个示例,我们可以重用步骤1和DeepLabV3模型中心站点中的代码。将以下代码片段添加到上面的代码中

from PIL import Image
from torchvision import transforms
input_image = Image.open("deeplab.jpg")
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)
with torch.no_grad():
    output = model(input_batch)['out'][0]

print(input_batch.shape)
print(output.shape)

这里下载deeplab.jpg,然后运行上面的脚本,您将看到模型输入和输出的形状。

torch.Size([1, 3, 400, 400])
torch.Size([21, 400, 400])

因此,如果您向Android上的模型提供相同大小为400x400的图像输入deeplab.jpg,则模型的输出应具有大小[21, 400, 400]。您还应该打印出输入和输出的实际数据(至少是开头部分),以便在下面的步骤4中与Android应用程序中运行模型时的实际输入和输出进行比较。

3. 构建新的Android应用程序或重用示例应用程序并加载模型

首先,请按照Android模型准备指南中的步骤3,在启用了PyTorch Mobile的Android Studio项目中使用我们的模型。因为本教程中使用的DeepLabV3和PyTorch Hello World Android示例中使用的MobileNet v2都是计算机视觉模型,所以您也可以获取Hello World示例代码库,以便更轻松地修改加载模型和处理输入和输出的代码。此步骤和步骤4的主要目标是确保在步骤1中生成的模型deeplabv3_scripted.pt确实可以在Android上正常工作。

现在,让我们将步骤2中使用的deeplabv3_scripted.ptdeeplab.jpg添加到Android Studio项目中,并将MainActivity中的onCreate方法修改为类似于以下代码:

Module module = null;
try {
  module = Module.load(assetFilePath(this, "deeplabv3_scripted.pt"));
} catch (IOException e) {
  Log.e("ImageSegmentation", "Error loading model!", e);
  finish();
}

然后在finish()行设置断点,构建并运行应用程序。如果应用程序没有在断点处停止,则表示步骤1中脚本化的模型已成功加载到Android上。

4. 处理模型输入和输出以进行模型推理

在上一步骤加载模型后,让我们验证它是否可以使用预期的输入并生成预期的输出。由于DeepLabV3模型的模型输入与Hello World示例中的MobileNet v2的模型输入相同,因此我们将重用Hello World中MainActivity.java文件中的部分代码进行输入处理。将MainActivity.java第50行到第73行之间的代码片段替换为以下代码

final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
        TensorImageUtils.TORCHVISION_NORM_STD_RGB);
final float[] inputs = inputTensor.getDataAsFloatArray();

Map<String, IValue> outTensors =
    module.forward(IValue.from(inputTensor)).toDictStringKey();

// the key "out" of the output tensor contains the semantic masks
// see https://pytorch.ac.cn/hub/pytorch_vision_deeplabv3_resnet101
final Tensor outputTensor = outTensors.get("out").toTensor();
final float[] outputs = outputTensor.getDataAsFloatArray();

int width = bitmap.getWidth();
int height = bitmap.getHeight();

注意

对于DeepLabV3模型,模型输出是一个字典,因此我们使用toDictStringKey来正确提取结果。对于其他模型,模型输出也可能是一个单个张量或张量的元组,以及其他内容。

通过上述代码更改,您可以在final float[] inputsfinal float[] outputs之后设置断点,这些断点将输入张量和输出张量数据填充到浮点数组中,以便于调试。运行应用程序,当它在断点处停止时,将inputsoutputs中的数字与步骤2中看到的模型输入和输出数据进行比较,查看它们是否匹配。对于在Android和Python上运行的模型的相同输入,您应该获得相同的输出。

警告

由于某些Android模拟器的浮点实现问题,在Android模拟器上运行时,使用相同的图像输入可能会看到不同的模型输出。因此,最好在真实的Android设备上测试应用程序。

到目前为止,我们所做的只是确认我们感兴趣的模型可以被脚本化并在我们的Android应用程序中像在Python中一样正确运行。我们迄今为止介绍的用于在iOS应用程序中使用模型的步骤占据了我们应用程序开发的大部分时间,如果不是全部的话,这类似于数据预处理是典型机器学习项目中最繁重的工作。

5. 完成UI,重构,构建并运行应用程序

现在,我们准备完成应用程序和UI,以便实际看到处理后的结果作为新图像。输出处理代码应如下所示,添加到步骤4中的代码片段末尾

int[] intValues = new int[width * height];
// go through each element in the output of size [WIDTH, HEIGHT] and
// set different color for different classnum
for (int j = 0; j < width; j++) {
    for (int k = 0; k < height; k++) {
        // maxi: the index of the 21 CLASSNUM with the max probability
        int maxi = 0, maxj = 0, maxk = 0;
        double maxnum = -100000.0;
        for (int i=0; i < CLASSNUM; i++) {
            if (outputs[i*(width*height) + j*width + k] > maxnum) {
                maxnum = outputs[i*(width*height) + j*width + k];
                maxi = i; maxj = j; maxk= k;
            }
        }
        // color coding for person (red), dog (green), sheep (blue)
        // black color for background and other classes
        if (maxi == PERSON)
            intValues[maxj*width + maxk] = 0xFFFF0000; // red
        else if (maxi == DOG)
            intValues[maxj*width + maxk] = 0xFF00FF00; // green
        else if (maxi == SHEEP)
            intValues[maxj*width + maxk] = 0xFF0000FF; // blue
        else
            intValues[maxj*width + maxk] = 0xFF000000; // black
    }
}

上面代码中使用的常量在MainActivity类的开头定义

private static final int CLASSNUM = 21;
private static final int DOG = 12;
private static final int PERSON = 15;
private static final int SHEEP = 17;

此处的实现基于对DeepLabV3模型的理解,该模型为宽度*高度的输入图像输出大小为[21, 宽度, 高度]的张量。宽度*高度输出数组中的每个元素都是0到20之间的值(总共21个语义标签,如引言中所述),并且该值用于设置特定的颜色。此处的分割颜色编码基于概率最高的类,您可以扩展您自己数据集中的所有类的颜色编码。

在输出处理之后,您还需要调用以下代码将RGB intValues数组渲染到位图实例outputBitmap,然后再将其显示在ImageView

Bitmap bmpSegmentation = Bitmap.createScaledBitmap(bitmap, width, height, true);
Bitmap outputBitmap = bmpSegmentation.copy(bmpSegmentation.getConfig(), true);
outputBitmap.setPixels(intValues, 0, outputBitmap.getWidth(), 0, 0,
    outputBitmap.getWidth(), outputBitmap.getHeight());
imageView.setImageBitmap(outputBitmap);

此应用程序的UI也类似于Hello World的UI,除了您不需要TextView来显示图像分类结果。您还可以添加两个按钮SegmentRestart(如代码库中所示),以运行模型推理并在显示分割结果后显示回原始图像。

现在,当您在Android模拟器或优选的实际设备上运行应用程序时,您将看到如下屏幕

../_images/deeplabv3_android.png ../_images/deeplabv3_android2.png

总结

在本教程中,我们描述了将预训练的PyTorch DeepLabV3模型转换为适用于Android的格式所需的操作,以及如何确保模型可以在Android上成功运行。我们的重点是帮助您了解确认模型确实可以在Android上运行的过程。完整的代码库可从这里获取。

诸如量化以及通过迁移学习或在Android上使用您自己的模型等更高级主题将在未来的演示应用程序和教程中很快介绍。

文档

访问PyTorch的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源