快捷方式

    注意

    PyTorch Mobile 已不再积极支持。请查看 ExecuTorch,这是 PyTorch 全新的设备端推理库。您也可以查阅 此页面,了解如何使用 ExecuTorch 构建 Android 应用。

    Android

    使用 HelloWorld 示例快速入门

    HelloWorld 是一个简单的图像分类应用,演示了如何使用 PyTorch Android API。此应用在静态图片上运行通过 TorchScript 序列化、来自 TorchVision 预训练的 resnet18 模型,该图片作为 Android asset 打包在应用内部。

    1. 模型准备

    让我们从模型准备开始。如果您熟悉 PyTorch,您可能已经知道如何训练和保存您的模型。如果您不熟悉,我们将使用一个预训练的图像分类模型 (MobileNetV2)。要安装它,请运行以下命令

    pip install torchvision
    

    要序列化模型,您可以使用 HelloWorld 应用根文件夹中的 python 脚本

    import torch
    import torchvision
    from torch.utils.mobile_optimizer import optimize_for_mobile
    
    model = torchvision.models.mobilenet_v2(pretrained=True)
    model.eval()
    example = torch.rand(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module_optimized = optimize_for_mobile(traced_script_module)
    traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")
    
    

    如果一切顺利,我们的模型文件 - model.ptl 应该已生成在 Android 应用的 assets 文件夹中。该文件将作为 asset 打包在 Android 应用内部,并可在设备上使用。

    您可以在 pytorch.org 上的教程中找到更多关于 TorchScript 的详细信息

    2. 从 github 克隆

    git clone https://github.com/pytorch/android-demo-app.git
    cd HelloWorldApp
    

    如果您已经安装了 Android SDKAndroid NDK,您可以使用以下命令将此应用安装到连接的 Android 设备或模拟器:

    ./gradlew installDebug
    

    我们建议您在 Android Studio 3.5.1+ 中打开此项目。目前 PyTorch Android 和演示应用使用的是 android gradle 插件的 3.5.0 版本,该版本仅由 Android Studio 3.5.1 及更高版本支持。使用 Android Studio,您将能够通过 Android Studio UI 安装 Android NDK 和 Android SDK。

    3. Gradle 依赖项

    PyTorch Android 作为 gradle 依赖项被添加到 HelloWorld 的 build.gradle 文件中

    repositories {
        jcenter()
    }
    
    dependencies {
        implementation 'org.pytorch:pytorch_android_lite:1.9.0'
        implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
    }
    

    其中 org.pytorch:pytorch_android 是主要的依赖项,包含 PyTorch Android API,以及支持所有 4 种 Android ABI (armeabi-v7a, arm64-v8a, x86, x86_64) 的 libtorch 本地库。本文档的后续部分将介绍如何仅针对特定的 Android ABI 列表重新构建它。

    org.pytorch:pytorch_android_torchvision - 一个附加库,提供将 android.media.Imageandroid.graphics.Bitmap 转换为张量的实用函数。

    4. 从 Android Asset 读取图像

    所有逻辑都在 org.pytorch.helloworld.MainActivity 中进行。第一步是使用标准的 Android API 将 image.jpg 读取到 android.graphics.Bitmap 中。

    Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
    

    5. 加载移动端模块

    Module module = Module.load(assetFilePath(this, "model.ptl"));
    

    org.pytorch.Module 代表 torch::jit::mobile::Module,可以使用 load 方法加载,该方法需要指定序列化到文件中的模型的路径。

    6. 准备输入

    Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
    

    org.pytorch.torchvision.TensorImageUtilsorg.pytorch:pytorch_android_torchvision 库的一部分。TensorImageUtils#bitmapToFloat32Tensor 方法使用 android.graphics.Bitmap 作为源,创建 torchvision 格式的张量。

    所有预训练模型都期望输入图像以相同方式进行归一化,即形状为 (3 x H x W) 的 3 通道 RGB 图像的 mini-batch,其中 H 和 W 预期至少为 224。图像必须加载到 [0, 1] 范围,然后使用 mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225] 进行归一化。

    inputTensor 的形状为 1x3xHxW,其中 HW 分别对应位图的高度和宽度。

    7. 运行推理

    Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
    float[] scores = outputTensor.getDataAsFloatArray();
    

    org.pytorch.Module.forward 方法运行已加载模块的 forward 方法,并将结果作为形状为 1x1000org.pytorch.Tensor outputTensor 返回。

    8. 处理结果

    其内容通过 org.pytorch.Tensor.getDataAsFloatArray() 方法检索,该方法返回一个 Java 浮点数组,其中包含每个 ImageNet 类别的得分。

    之后,我们只需找到得分最高的索引,并从包含所有 ImageNet 类别的 ImageNetClasses.IMAGENET_CLASSES 数组中检索预测的类别名称。

    float maxScore = -Float.MAX_VALUE;
    int maxScoreIdx = -1;
    for (int i = 0; i < scores.length; i++) {
      if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxScoreIdx = i;
      }
    }
    String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
    

    在以下部分中,您将找到 PyTorch Android API 的详细说明、一个更大的 演示应用的详细代码解析、API 的实现细节以及如何从源代码自定义和构建它。

    PyTorch 演示应用

    我们还在 同一 github 仓库中创建了另一个更复杂的 PyTorch Android 演示应用,它可以对相机输出进行图像分类,并进行文本分类。

    为了获取设备相机输出,它使用了 Android CameraX API。所有与 CameraX 相关的逻辑都已分离到 org.pytorch.demo.vision.AbstractCameraXActivity 类中。

    void setupCameraX() {
        final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
        final Preview preview = new Preview(previewConfig);
        preview.setOnPreviewOutputUpdateListener(output -> mTextureView.setSurfaceTexture(output.getSurfaceTexture()));
    
        final ImageAnalysisConfig imageAnalysisConfig =
            new ImageAnalysisConfig.Builder()
                .setTargetResolution(new Size(224, 224))
                .setCallbackHandler(mBackgroundHandler)
                .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
                .build();
        final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
        imageAnalysis.setAnalyzer(
            (image, rotationDegrees) -> {
              analyzeImage(image, rotationDegrees);
            });
    
        CameraX.bindToLifecycle(this, preview, imageAnalysis);
      }
    
      void analyzeImage(android.media.Image, int rotationDegrees)
    

    analyzeImage 方法处理相机输出,即 android.media.Image

    它使用上述的 TensorImageUtils.imageYUV420CenterCropToFloat32Tensor 方法将 YUV420 格式的 android.media.Image 转换为输入张量。

    从模型获取预测得分后,它会找到得分最高的 Top K 个类别并显示在 UI 上。

    语言处理示例

    另一个示例是自然语言处理,基于一个 LSTM 模型,该模型在一个 reddit 评论数据集上训练。逻辑在 TextClassificattionActivity 中进行。

    结果类别名称被打包在 TorchScript 模型中,并在初始模块初始化后立即初始化。该模块有一个 get_classes 方法,返回 List[str],可以使用 Module.runMethod(methodName) 方法调用。

        mModule = Module.load(moduleFileAbsoluteFilePath);
        IValue getClassesOutput = mModule.runMethod("get_classes");
    

    返回的 IValue 可以使用 IValue.toList() 转换为 IValue 的 Java 数组,然后使用 IValue.toStr() 处理为字符串数组。

        IValue[] classesListIValue = getClassesOutput.toList();
        String[] moduleClasses = new String[classesListIValue.length];
        int i = 0;
        for (IValue iv : classesListIValue) {
          moduleClasses[i++] = iv.toStr();
        }
    

    输入文本使用 UTF-8 编码转换为 Java 字节数组。Tensor.fromBlobUnsigned 从该字节数组创建一个 dtype=uint8 的张量。

        byte[] bytes = text.getBytes(Charset.forName("UTF-8"));
        final long[] shape = new long[]{1, bytes.length};
        final Tensor inputTensor = Tensor.fromBlobUnsigned(bytes, shape);
    

    运行模型的推理过程与之前的示例类似

    Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor()
    

    之后,代码处理输出,找到得分最高的类别。

    更多 PyTorch Android 演示应用

    D2go

    D2Go 演示了一个 Python 脚本,该脚本创建了更轻更快、由 PyTorch 1.8、torchvision 0.9 和 Detectron2 支持的 Facebook D2Go 模型(内置了面向移动端的 SOTA 网络),以及一个使用该模型从照片、相机拍摄的照片或实时相机中检测物体的 Android 应用。这个演示应用还展示了如何使用原生预构建的 torchvision-ops 库。

    图像分割

    图像分割 演示了一个 Python 脚本,该脚本转换了 PyTorch DeepLabV3 模型,以及一个使用该模型进行图像分割的 Android 应用。

    物体检测

    物体检测 演示了如何转换流行的 YOLOv5 模型,并在一个 Android 应用中使用它来检测照片、相机拍摄的照片或实时相机中的物体。

    神经机器翻译

    神经机器翻译 演示了如何转换一个使用 PyTorch NMT 教程中的代码训练的序列到序列神经机器翻译模型,并在一个 Android 应用中使用该模型进行法英翻译。

    问答系统

    问答系统 演示了如何转换一个强大的 transformer QA 模型,并在一个 Android 应用中使用该模型回答关于 PyTorch Mobile 等问题。

    Vision Transformer

    Vision Transformer 演示了如何使用 Facebook 最新的 Vision Transformer DeiT 模型进行图像分类,以及如何转换另一个 Vision Transformer 模型并在一个 Android 应用中使用它进行手写数字识别。

    语音识别

    语音识别 演示了如何将 Facebook AI 的 wav2vec 2.0(语音识别领域领先模型之一)转换为 TorchScript,以及如何在 Android 应用中使用脚本化模型进行语音识别。

    视频分类

    TorchVideo 演示了如何在 Android 设备上使用新发布的 PyTorchVideo 中提供的预训练视频分类模型,以查看视频播放时每秒更新的视频分类结果,可用于测试视频、相册视频或实时视频。

    PyTorch Android 教程和精粹示例

    Android 上的图像分割 DeepLabV3

    关于如何在 Android 设备上准备和运行 PyTorch DeepLabV3 图像分割模型的详细分步教程。

    PyTorch Mobile 性能优化精粹示例

    在移动设备上使用 PyTorch 进行性能优化的精粹示例列表。

    构建使用 PyTorch Android 预构建库的 Android 原生应用

    学习如何从头开始构建使用 LibTorch C++ API 和带有自定义 C++ 算子的 TorchScript 模型的 Android 应用。

    融合模块精粹示例

    学习如何将一系列 PyTorch 模块融合成一个模块,以在量化前减小模型大小。

    移动端量化精粹示例

    学习如何减小模型大小并使其运行更快,同时尽量不损失准确性。

    为移动端进行脚本化和优化

    学习如何将模型转换为 TorchScript 并(可选)为移动应用进行优化。

    Android 模型准备精粹示例

    学习如何在 Android 项目中添加模型并使用 PyTorch Android 库。

    从源代码构建 PyTorch Android

    在某些情况下,您可能希望使用本地构建的 PyTorch Android,例如,您可能希望构建带有另一组算子的自定义 LibTorch 二进制文件,或者进行本地更改,或者尝试最新的 PyTorch 代码。

    为此,您可以使用 ./scripts/build_pytorch_android.sh 脚本。

    git clone https://github.com/pytorch/pytorch.git
    cd pytorch
    sh ./scripts/build_pytorch_android.sh
    

    工作流程包含几个步骤

    1. 为所有 4 种 Android ABI (armeabi-v7a, arm64-v8a, x86, x86_64) 构建 Android 版 libtorch

    2. 为这些构建结果创建符号链接:android/pytorch_android/src/main/jniLibs/${abi} 指向输出库所在的目录,android/pytorch_android/src/main/cpp/libtorch_include/${abi} 指向头文件所在的目录。这些目录用于构建 libpytorch_jni.so 库,作为 pytorch_android-release.aar 包的一部分,该包将在 Android 设备上加载。

    3. 最后在 android/pytorch_android 目录中运行 gradle,并执行 assembleRelease 任务

    脚本要求已安装 Android SDK、Android NDK、Java SDK 和 gradle。它们通过环境变量指定

    ANDROID_HOME - Android SDK 的路径

    ANDROID_NDK - Android NDK 的路径。推荐使用 NDK 21.x。

    GRADLE_HOME - gradle 的路径

    JAVA_HOME - JAVA JDK 的路径

    构建成功后,您应该会看到结果为一个 aar 文件

    $ find android -type f -name *aar
    android/pytorch_android/build/outputs/aar/pytorch_android-release.aar
    android/pytorch_android_torchvision/build/outputs/aar/pytorch_android_torchvision-release.aar
    

    使用从源代码或 nightly 构建的 PyTorch Android 库

    首先将上面构建的两个 aar 文件,或者从 这里这里 的 nightly 构建的 PyTorch Android 仓库下载的文件添加到 Android 项目的 lib 文件夹中,然后在项目的 app build.gradle 文件中添加

    allprojects {
        repositories {
            flatDir {
                dirs 'libs'
            }
        }
    }
    
    dependencies {
    
        // if using the libraries built from source
        implementation(name:'pytorch_android-release', ext:'aar')
        implementation(name:'pytorch_android_torchvision-release', ext:'aar')
    
        // if using the nightly built libraries downloaded above, for example the 1.8.0-snapshot on Jan. 21, 2021
        // implementation(name:'pytorch_android-1.8.0-20210121.092759-172', ext:'aar')
        // implementation(name:'pytorch_android_torchvision-1.8.0-20210121.092817-173', ext:'aar')
    
        ...
        implementation 'com.android.support:appcompat-v7:28.0.0'
        implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3'
    }
    

    此外,我们还必须添加我们 aar 文件的所有传递依赖项。由于 pytorch_android 依赖于 com.android.support:appcompat-v7:28.0.0androidx.appcompat:appcompat:1.2.0,我们需要其中之一。(如果使用 maven 依赖项,它们将自动从 pom.xml 中添加)。

    使用 nightly PyTorch Android 库

    除了使用从源代码构建或从上一节链接下载的 aar 文件外,您还可以通过在 app build.gradle 文件中添加 maven url 和 nightly 库的实现来使用 nightly 构建的 Android PyTorch 和 TorchVision 库,如下所示

    repositories {
        maven {
            url "https://oss.sonatype.org/content/repositories/snapshots"
        }
    }
    
    dependencies {
        ...
        implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
        implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
    }
    

    如果您不需要进行任何本地更改,这是尝试最新的 PyTorch 代码和 Android 库的最简单方法。但请注意,您可能需要使用最新的 PyTorch(无论是使用最新的 PyTorch 代码还是通过 pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html 等命令快速安装 nightly 版本)来构建在移动设备上使用的模型,以避免在移动设备上运行模型时可能出现的模型版本不匹配错误。

    自定义构建

    为了减小二进制文件的大小,您可以对 PyTorch Android 进行自定义构建,只包含您的模型所需的算子集。这包括两个步骤:从您的模型中准备算子列表,使用指定列表重新构建 PyTorch Android。

    1. 验证您的 PyTorch 版本是否为 1.4.0 或更高。您可以通过检查 torch.__version__ 的值来做到这一点。

    2. 算子列表的准备

    您序列化的 TorchScript 模型的算子列表可以使用 python api 函数 torch.jit.export_opnames() 以 yaml 格式准备。要导出模型(例如 MobileNetV2)中的算子,请运行以下 Python 代码行

    # Dump list of operators used by MobileNetV2:
    import torch, yaml
    model = torch.jit.load('MobileNetV2.pt')
    ops = torch.jit.export_opnames(model)
    with open('MobileNetV2.yaml', 'w') as output:
        yaml.dump(ops, output)
    

    3. 使用准备好的算子列表构建 PyTorch Android。

    要使用准备好的 yaml 算子列表构建 PyTorch Android,请在环境变量 SELECTED_OP_LIST 中指定它。此外,在参数中指定应构建哪些 Android ABI;默认情况下构建所有 4 个 Android ABI。

    # Build PyTorch Android library customized for MobileNetV2:
    SELECTED_OP_LIST=MobileNetV2.yaml scripts/build_pytorch_android.sh arm64-v8a
    

    构建成功后,您可以按照本教程上一节(从源代码构建 PyTorch Android)的步骤将结果 aar 文件集成到您的 Android gradle 项目中。

    使用 PyTorch JIT 解释器

    PyTorch JIT 解释器是 1.9 之前的默认解释器(它是我们的 PyTorch 解释器的一个版本,效率不如新版本)。它在 1.9 中仍将得到支持,并且可以通过 build.gradle 使用

    repositories {
        jcenter()
    }
    
    dependencies {
        implementation 'org.pytorch:pytorch_android:1.9.0'
        implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
    }
    

    Android 教程

    观看以下视频,PyTorch 合作伙伴工程师 Brad Heintz 将逐步讲解如何为 Android 项目设置 PyTorch Runtime

    PyTorch Mobile Runtime for Android

    相应的代码可以在这里找到。

    查看我们的移动性能优化精粹示例,其中介绍了如何优化您的模型以及如何通过基准测试检查优化是否有效。

    此外,按照此精粹示例学习如何构建使用 PyTorch 预构建库的 Android 原生应用。

    API 文档

    您可以在Javadoc中找到更多关于 PyTorch Android API 的详细信息。