• 教程 >
  • 通过 Flask 使用 REST API 在 Python 中部署 PyTorch
快捷方式

通过 Flask 使用 REST API 在 Python 中部署 PyTorch

创建于: 2019 年 7 月 3 日 | 最后更新于: 2024 年 1 月 19 日 | 最后验证于: 2024 年 11 月 5 日

作者: Avinash Sajjanshetty

在本教程中,我们将使用 Flask 部署一个 PyTorch 模型,并公开一个 REST API 用于模型推理。特别是,我们将部署一个用于图像检测的预训练 DenseNet 121 模型。

提示

此处使用的所有代码均根据 MIT 许可发布,可在 Github 上获取。

这是关于在生产环境中部署 PyTorch 模型系列教程的第一部分。以这种方式使用 Flask 是开始服务你的 PyTorch 模型迄今为止最简单的方法,但它不适用于具有高性能要求的用例。对于这种情况,

API 定义

我们将首先定义 API 端点、请求和响应类型。我们的 API 端点将是 /predict,它接受带有包含图像的 file 参数的 HTTP POST 请求。响应将是包含预测结果的 JSON 格式响应。

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

依赖项

运行以下命令安装所需的依赖项

pip install Flask==2.0.1 torchvision==0.10.0

简单的 Web 服务器

以下是一个简单的 Web 服务器,摘自 Flask 的文档

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

我们还将更改响应类型,使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。更新后的 app.py 文件现在如下所示

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在接下来的部分中,我们将重点编写推理代码。这将包括两个部分:第一部分准备图像,使其可以馈送到 DenseNet;第二部分编写代码以从模型获取实际预测。

准备图像

DenseNet 模型要求图像是大小为 224 x 224 的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化。你可以在此处阅读更多相关信息。

我们将使用 torchvision 库中的 transforms 并构建一个变换管道,它会根据需要变换我们的图像。你可以在此处阅读更多关于变换的信息。

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法接收字节格式的图像数据,应用一系列变换并返回一个张量。要测试上述方法,请以字节模式读取图像文件(首先将 ../_static/img/sample_file.jpeg 替换为你计算机上的实际文件路径),然后查看是否能获得张量。

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

预测

现在将使用预训练的 DenseNet 121 模型来预测图像类别。我们将使用 torchvision 库中的一个模型,加载模型并进行推理。虽然本示例中使用的是预训练模型,但你可以将此方法应用于你自己的模型。请在此教程中了解更多关于加载模型的信息。

from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

张量 y_hat 将包含预测类别 ID 的索引。但是,我们需要一个人类可读的类别名称。为此,我们需要类别 ID 到名称的映射。将此文件下载为 imagenet_class_index.json 并记住保存位置(或者,如果你完全按照本教程的步骤操作,请将其保存在 tutorials/_static 中)。此文件包含 ImageNet 类别 ID 到 ImageNet 类别名称的映射。我们将加载此 JSON 文件并获取预测索引对应的类别名称。

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用 imagenet_class_index 字典之前,我们首先将张量值转换为字符串值,因为 imagenet_class_index 字典中的键是字符串。我们将测试上述方法

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

你应该会得到如下所示的响应

['n02124075', 'Egyptian_cat']

数组中的第一项是 ImageNet 类别 ID,第二项是人类可读的名称。

将模型集成到我们的 API 服务器中

在最后这部分,我们将把模型添加到 Flask API 服务器中。由于我们的 API 服务器应该接收图像文件,我们将更新 predict 方法以从请求中读取文件。

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()
FLASK_ENV=development FLASK_APP=app.py flask run

库向我们的应用发送 POST 请求

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

现在打印 resp.json() 将显示如下内容

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

我们编写的服务器非常简单,可能无法满足你的生产应用所需的一切。因此,这里有一些可以改进它的地方

  • 端点 /predict 假设请求中始终存在图像文件。这对于所有请求可能不成立。用户可能会使用不同的参数发送图像,或者根本不发送图像。

  • 用户也可能发送非图像类型的文件。由于我们没有处理错误,这将导致服务器崩溃。添加一个显式的错误处理路径来抛出异常,将使我们能够更好地处理不良输入。

  • 即使模型能够识别大量图像类别,它也可能无法识别所有图像。增强实现来处理模型无法识别图像中任何内容的情况。

  • 我们在开发模式下运行 Flask 服务器,这不适合在生产环境中部署。你可以查看此教程,了解如何在生产环境中部署 Flask 服务器。

  • 你还可以通过创建一个包含表单的页面来添加用户界面 (UI),该表单接收图像并显示预测结果。请查看一个类似项目的演示及其源代码

  • 在本教程中,我们只展示了如何构建一个一次只能返回单个图像预测结果的服务。我们可以修改服务使其能够一次返回多个图像的预测结果。此外,service-streamer 库会自动将发送到你的服务的请求排队,并将它们采样成可以输入到模型中的 mini-batches。你可以查看此教程

  • 最后,我们鼓励你查看页面顶部链接到的其他关于部署 PyTorch 模型的教程。

脚本总运行时间: ( 0 分 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源