快捷方式

    注意

    PyTorch Mobile 不再得到积极支持。请查看 ExecuTorch,PyTorch 的全新设备上推理库。您还可以查看 此页面 以了解有关如何使用 ExecuTorch 构建 Android 应用程序的更多信息。

    安卓

    使用 HelloWorld 示例快速入门

    HelloWorld 是一个简单的图像分类应用程序,演示了如何使用 PyTorch Android API。此应用程序在静态图像上运行 TorchScript 序列化 TorchVision 预训练的 resnet18 模型,该图像作为 android 资产打包在应用程序中。

    1. 模型准备

    让我们从模型准备开始。如果您熟悉 PyTorch,您可能已经知道如何训练和保存模型。如果您不知道,我们将使用预训练的图像分类模型 (MobileNetV2)。要安装它,请运行以下命令

    pip install torchvision
    

    要序列化模型,您可以使用 HelloWorld 应用程序根文件夹中的 python 脚本

    import torch
    import torchvision
    from torch.utils.mobile_optimizer import optimize_for_mobile
    
    model = torchvision.models.mobilenet_v2(pretrained=True)
    model.eval()
    example = torch.rand(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module_optimized = optimize_for_mobile(traced_script_module)
    traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")
    
    

    如果一切正常,我们应该在 android 应用程序的 assets 文件夹中生成模型 - model.ptl。它将作为 asset 打包在 android 应用程序中,可以在设备上使用。

    有关 TorchScript 的更多详细信息,请参见 pytorch.org 上的教程

    2. 从 github 克隆

    git clone https://github.com/pytorch/android-demo-app.git
    cd HelloWorldApp
    

    如果 Android SDKAndroid NDK 已经安装,您可以使用以下命令将此应用程序安装到连接的 android 设备或模拟器上

    ./gradlew installDebug
    

    我们建议您在 Android Studio 3.5.1+ 中打开此项目。目前,PyTorch Android 和演示应用程序使用 android gradle 插件版本 3.5.0,该插件仅受 Android Studio 版本 3.5.1 及更高版本支持。使用 Android Studio,您可以使用 Android Studio UI 安装 Android NDK 和 Android SDK。

    3. Gradle 依赖项

    Pytorch android 已作为 gradle 依赖项 添加到 HelloWorld 中的 build.gradle 文件

    repositories {
        jcenter()
    }
    
    dependencies {
        implementation 'org.pytorch:pytorch_android_lite:1.9.0'
        implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
    }
    

    其中 org.pytorch:pytorch_android 是带有 PyTorch Android API 的主要依赖项,包括针对所有 4 个 android abi(armeabi-v7a、arm64-v8a、x86、x86_64)的 libtorch 本机库。在本文档的后面,您将找到如何仅针对特定 android abi 列表重建它的方法。

    org.pytorch:pytorch_android_torchvision - 额外的库,提供用于将 android.media.Imageandroid.graphics.Bitmap 转换为张量的实用函数。

    4. 从 Android 资产读取图像

    所有逻辑都在 org.pytorch.helloworld.MainActivity 中。作为第一步,我们使用标准 Android API 将 image.jpg 读取到 android.graphics.Bitmap 中。

    Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
    

    5. 加载 Mobile 模块

    Module module = Module.load(assetFilePath(this, "model.ptl"));
    

    org.pytorch.Module 代表 torch::jit::mobile::Module,可以使用 load 方法加载,该方法指定序列化到文件模型的文件路径。

    6. 准备输入

    Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
    

    org.pytorch.torchvision.TensorImageUtilsorg.pytorch:pytorch_android_torchvision 库的一部分。该 TensorImageUtils#bitmapToFloat32Tensor 方法使用 android.graphics.Bitmap 作为源创建 torchvision 格式 的张量。

    所有预训练模型都期望输入图像以相同的方式进行归一化,即形状为 (3 x H x W) 的 3 通道 RGB 图像的小批量,其中 H 和 W 至少应为 224。图像必须加载到 [0, 1] 范围内,然后使用 mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225] 进行归一化

    inputTensor 的形状为 1x3xHxW,其中 HW 分别是位图高度和宽度。

    7. 运行推理

    Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
    float[] scores = outputTensor.getDataAsFloatArray();
    

    org.pytorch.Module.forward 方法运行加载的模块的 forward 方法,并将结果作为 org.pytorch.Tensor outputTensor 获得,形状为 1x1000

    8. 处理结果

    它的内容使用 org.pytorch.Tensor.getDataAsFloatArray() 方法检索,该方法返回包含每个 ImageNet 类别的得分的浮点数的 java 数组。

    之后,我们只需找到得分最高的索引,并从 ImageNetClasses.IMAGENET_CLASSES 数组中检索预测的类别名称,该数组包含所有 ImageNet 类别。

    float maxScore = -Float.MAX_VALUE;
    int maxScoreIdx = -1;
    for (int i = 0; i < scores.length; i++) {
      if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxScoreIdx = i;
      }
    }
    String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
    

    在以下部分,您将找到对 PyTorch Android API 的详细解释,一个更大的 演示应用程序 的代码演练,API 的实现细节,如何自定义和从源代码构建 API。

    PyTorch 演示应用程序

    我们还创建了另一个更复杂的 PyTorch Android 演示应用程序,该应用程序从相机输出执行图像分类,并在 同一个 github 仓库 中进行文本分类。

    要获取设备相机输出,它使用 Android CameraX API。所有与 CameraX 相关的逻辑都已分离到 org.pytorch.demo.vision.AbstractCameraXActivity 类中。

    void setupCameraX() {
        final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
        final Preview preview = new Preview(previewConfig);
        preview.setOnPreviewOutputUpdateListener(output -> mTextureView.setSurfaceTexture(output.getSurfaceTexture()));
    
        final ImageAnalysisConfig imageAnalysisConfig =
            new ImageAnalysisConfig.Builder()
                .setTargetResolution(new Size(224, 224))
                .setCallbackHandler(mBackgroundHandler)
                .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
                .build();
        final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
        imageAnalysis.setAnalyzer(
            (image, rotationDegrees) -> {
              analyzeImage(image, rotationDegrees);
            });
    
        CameraX.bindToLifecycle(this, preview, imageAnalysis);
      }
    
      void analyzeImage(android.media.Image, int rotationDegrees)
    

    其中,analyzeImage 方法处理相机输出,android.media.Image

    它使用前面提到的 TensorImageUtils.imageYUV420CenterCropToFloat32Tensor 方法将 android.media.Image 转换为 YUV420 格式的输入张量。

    在从模型中获得预测得分后,它会找到得分最高的 top K 类,并在 UI 上显示。

    语言处理示例

    另一个示例是自然语言处理,基于在 reddit 评论数据集上训练的 LSTM 模型。逻辑发生在 TextClassificattionActivity 中。

    结果类别名称打包在 TorchScript 模型中,并在初始模块初始化后立即初始化。该模块具有一个 get_classes 方法,该方法返回 List[str],可以使用 Module.runMethod(methodName) 方法调用。

        mModule = Module.load(moduleFileAbsoluteFilePath);
        IValue getClassesOutput = mModule.runMethod("get_classes");
    

    返回的 IValue 可以使用 IValue.toList() 转换为 java 的 IValue 数组,并使用 IValue.toStr() 处理为字符串数组。

        IValue[] classesListIValue = getClassesOutput.toList();
        String[] moduleClasses = new String[classesListIValue.length];
        int i = 0;
        for (IValue iv : classesListIValue) {
          moduleClasses[i++] = iv.toStr();
        }
    

    输入的文本将使用 UTF-8 编码转换为 Java 字节数组。 Tensor.fromBlobUnsigned 从该字节数组创建 dtype=uint8 的张量。

        byte[] bytes = text.getBytes(Charset.forName("UTF-8"));
        final long[] shape = new long[]{1, bytes.length};
        final Tensor inputTensor = Tensor.fromBlobUnsigned(bytes, shape);
    

    运行模型的推理与之前的示例类似。

    Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor()
    

    之后,代码处理输出,找到得分最高的类别。

    更多 PyTorch Android 演示应用程序

    D2go

    D2Go 演示了一个 Python 脚本,该脚本创建了更轻量级且更快的 Facebook D2Go 模型,该模型由 PyTorch 1.8、torchvision 0.9 和 Detectron2 提供支持,其中包含针对移动设备的 SOTA 网络,以及一个 Android 应用程序,该应用程序使用它从照片、相机拍摄的照片或实时相机中检测物体。 此演示应用程序还展示了如何使用本机预构建的 torchvision-ops 库。

    图像分割

    图像分割 演示了一个 Python 脚本,该脚本将 PyTorch DeepLabV3 模型转换为 Android 应用程序,该应用程序使用该模型分割图像。

    目标检测

    目标检测 演示了如何转换流行的 YOLOv5 模型,并在 Android 应用程序中使用它从照片、相机拍摄的照片或实时相机中检测物体。

    神经机器翻译

    神经机器翻译 演示了如何转换一个序列到序列神经机器翻译模型,该模型是使用 PyTorch NMT 教程 中的代码训练的,以及如何在 Android 应用程序中使用该模型进行法语-英语翻译。

    问答

    问答 演示了如何转换一个强大的 Transformer QA 模型,以及如何在 Android 应用程序中使用该模型来回答有关 PyTorch Mobile 等内容的问题。

    视觉 Transformer

    视觉 Transformer 演示了如何使用 Facebook 最新的视觉 Transformer DeiT 模型进行图像分类,以及如何转换另一个视觉 Transformer 模型并在 Android 应用程序中使用它来执行手写数字识别。

    语音识别

    语音识别 演示了如何将 Facebook AI 的 wav2vec 2.0(语音识别领域领先的模型之一)转换为 TorchScript,以及如何在 Android 应用程序中使用脚本化的模型执行语音识别。

    视频分类

    TorchVideo 演示了如何在 Android 上使用新发布的 PyTorchVideo 中提供的预训练视频分类模型,以查看视频分类结果,这些结果在视频播放时每秒更新一次,用于测试的视频、照片库中的视频,甚至是实时视频。

    PyTorch Android 教程和食谱

    在 Android 上的图像分割 DeepLabV3

    有关如何在 Android 上准备和运行 PyTorch DeepLabV3 图像分割模型的全面逐步教程。

    PyTorch Mobile 性能食谱

    在移动设备上使用 PyTorch 的性能优化的食谱列表。

    制作使用 PyTorch Android 预构建库的 Android 本机应用程序

    了解如何从头开始制作 Android 应用程序,该应用程序使用 LibTorch C++ API 并使用具有自定义 C++ 运算符的 TorchScript 模型。

    融合模块食谱

    了解如何在量化之前将 PyTorch 模块列表融合到单个模块中以减小模型大小。

    针对移动设备的量化食谱

    了解如何在不损失太多精度的同时减小模型大小并使其运行得更快。

    针对移动设备的脚本和优化

    了解如何将模型转换为 TorchScipt 并(可选)针对移动应用程序对其进行优化。

    针对 Android 模型准备食谱

    了解如何在 Android 项目中添加模型,并使用 PyTorch 库 for Android。

    从源代码构建 PyTorch Android

    在某些情况下,您可能希望使用本地构建的 PyTorch Android,例如,您可能使用另一组运算符构建自定义 LibTorch 二进制文件,或进行本地更改,或尝试最新的 PyTorch 代码。

    为此,您可以使用 ./scripts/build_pytorch_android.sh 脚本。

    git clone https://github.com/pytorch/pytorch.git
    cd pytorch
    sh ./scripts/build_pytorch_android.sh
    

    工作流程包含几个步骤

    1. 为所有 4 个 Android ABI 构建适用于 Android 的 libtorch(armeabi-v7a、arm64-v8a、x86、x86_64)

    2. 创建指向这些构建结果的符号链接: android/pytorch_android/src/main/jniLibs/${abi} 到包含输出库的目录 android/pytorch_android/src/main/cpp/libtorch_include/${abi} 到包含标头的目录。 这些目录用于构建 libpytorch_jni.so 库,作为 pytorch_android-release.aar 捆绑包的一部分,该捆绑包将在 Android 设备上加载。

    3. 最后,使用任务 assembleReleaseandroid/pytorch_android 目录中运行 gradle

    脚本要求已安装 Android SDK、Android NDK、Java SDK 和 gradle。 它们作为环境变量指定

    ANDROID_HOME - 指向 Android SDK 的路径

    ANDROID_NDK - 指向 Android NDK 的路径。 建议使用 NDK 21.x。

    GRADLE_HOME - 指向 gradle 的路径

    JAVA_HOME - 指向 JAVA JDK 的路径

    构建成功后,您应该看到结果为 aar 文件

    $ find android -type f -name *aar
    android/pytorch_android/build/outputs/aar/pytorch_android-release.aar
    android/pytorch_android_torchvision/build/outputs/aar/pytorch_android_torchvision-release.aar
    

    使用从源代码构建的或夜间构建的 PyTorch Android 库

    首先将上面构建的两个 aar 文件或从夜间构建的 PyTorch Android 存储库(在 此处此处)下载的两个 aar 文件添加到 Android 项目的 lib 文件夹中,然后在项目的 app build.gradle 文件中添加以下内容

    allprojects {
        repositories {
            flatDir {
                dirs 'libs'
            }
        }
    }
    
    dependencies {
    
        // if using the libraries built from source
        implementation(name:'pytorch_android-release', ext:'aar')
        implementation(name:'pytorch_android_torchvision-release', ext:'aar')
    
        // if using the nightly built libraries downloaded above, for example the 1.8.0-snapshot on Jan. 21, 2021
        // implementation(name:'pytorch_android-1.8.0-20210121.092759-172', ext:'aar')
        // implementation(name:'pytorch_android_torchvision-1.8.0-20210121.092817-173', ext:'aar')
    
        ...
        implementation 'com.android.support:appcompat-v7:28.0.0'
        implementation 'com.facebook.fbjni:fbjni-java-only:0.0.3'
    }
    

    我们还需要添加 aar 的所有传递依赖项。 由于 pytorch_android 依赖于 com.android.support:appcompat-v7:28.0.0androidx.appcompat:appcompat:1.2.0,因此我们需要其中一个。 (在使用 maven 依赖项的情况下,它们会从 pom.xml 自动添加)。

    使用夜间构建的 PyTorch Android 库

    除了使用从源代码构建的或从上一节链接下载的 aar 文件外,您还可以使用夜间构建的 Android PyTorch 和 TorchVision 库,方法是在您的 app build.gradle 文件中添加 maven url 和夜间库实现,如下所示

    repositories {
        maven {
            url "https://oss.sonatype.org/content/repositories/snapshots"
        }
    }
    
    dependencies {
        ...
        implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
        implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
    }
    

    这是尝试最新 PyTorch 代码和 Android 库的最简单方法,如果您不需要进行任何本地更改。 但请注意,您可能需要使用最新 PyTorch 构建移动设备上使用的模型 - 使用最新的 PyTorch 代码或使用类似 pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html 的命令快速安装夜间构建 - 以避免在移动设备上运行模型时出现可能的模型版本不匹配错误。

    自定义构建

    为了减小二进制文件的大小,您可以使用模型所需的运算符集进行 PyTorch Android 的自定义构建。 这包括两个步骤:从您的模型准备运算符列表,使用指定的列表重新构建 PyTorch Android。

    1. 确认您的 PyTorch 版本为 1.4.0 或更高。 您可以通过检查 torch.__version__ 的值来做到这一点。

    2. 准备运算符列表

    可以使用 Python API 函数 torch.jit.export_opnames() 以 yaml 格式准备序列化 torchscript 模型的运算符列表。 要转储模型中的运算符,例如 MobileNetV2,请运行以下几行 Python 代码

    # Dump list of operators used by MobileNetV2:
    import torch, yaml
    model = torch.jit.load('MobileNetV2.pt')
    ops = torch.jit.export_opnames(model)
    with open('MobileNetV2.yaml', 'w') as output:
        yaml.dump(ops, output)
    

    3. 使用准备好的运算符列表构建 PyTorch Android。

    要使用准备好的 yaml 运算符列表构建 PyTorch Android,请在环境变量 SELECTED_OP_LIST 中指定它。 在参数中,还应指定它应该构建哪些 Android ABI;默认情况下,它将构建所有 4 个 Android ABI。

    # Build PyTorch Android library customized for MobileNetV2:
    SELECTED_OP_LIST=MobileNetV2.yaml scripts/build_pytorch_android.sh arm64-v8a
    

    构建成功后,您可以按照本教程前一节(从源代码构建 PyTorch Android)中的步骤将结果 aar 文件集成到您的 Android gradle 项目中。

    使用 PyTorch JIT 解释器

    PyTorch JIT 解释器是 1.9 之前的默认解释器(我们 PyTorch 解释器的版本,其大小效率不高)。 它将在 1.9 中继续得到支持,并且可以通过 build.gradle 使用

    repositories {
        jcenter()
    }
    
    dependencies {
        implementation 'org.pytorch:pytorch_android:1.9.0'
        implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
    }
    

    Android 教程

    观看以下 视频,PyTorch 合作伙伴工程师 Brad Heintz 将逐步介绍为 Android 项目设置 PyTorch 运行时的步骤

    PyTorch Mobile Runtime for Android

    相应的代码可在 此处 找到。

    查看我们的 移动性能食谱,其中介绍了如何优化您的模型,以及如何通过基准测试检查优化是否有帮助。

    此外,请按照此食谱了解如何 制作使用 PyTorch 预构建库的本机 Android 应用程序

    API 文档

    您可以在 Javadoc 中找到有关 PyTorch Android API 的更多详细信息。