• 教程 >
  • 使用 Flask 通过 REST API 在 Python 中部署 PyTorch
快捷方式

使用 Flask 通过 REST API 在 Python 中部署 PyTorch

作者: Avinash Sajjanshetty

在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开一个用于模型推理的 REST API。特别是,我们将部署一个预训练的 DenseNet 121 模型,用于检测图像。

提示

此处使用的所有代码均在 MIT 许可证下发布,并在 Github 上提供。

这代表了在生产环境中部署 PyTorch 模型的一系列教程中的第一个。以这种方式使用 Flask 是开始服务您的 PyTorch 模型最简单的方法,但它不适用于对性能要求很高的用例。为此

API 定义

首先,我们将定义我们的 API 端点、请求和响应类型。我们的 API 端点将位于 /predict,它接收包含 file 参数的 HTTP POST 请求,该参数包含图像。响应将是包含预测的 JSON 响应。

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

依赖项

运行以下命令安装所需的依赖项:

pip install Flask==2.0.1 torchvision==0.10.0

简单的 Web 服务器

下面是一个简单的 Web 服务器,取自 Flask 的文档:

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

我们还将更改响应类型,使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。更新后的 app.py 文件如下所示:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在接下来的部分中,我们将专注于编写推理代码。这将涉及两个部分,一部分是我们准备图像以便将其馈送到 DenseNet,接下来,我们将编写代码以从模型中获取实际预测。

准备图像

DenseNet 模型要求图像为大小为 224 x 224 的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化。您可以在 此处了解更多信息。

我们将使用 torchvision 库中的 transforms 并构建一个转换管道,该管道根据需要转换我们的图像。您可以在 此处了解更多关于转换的信息。

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法获取以字节为单位的图像数据,应用一系列转换并返回一个张量。要测试上述方法,请以字节模式读取图像文件(首先将 ../_static/img/sample_file.jpeg 替换为您计算机上文件的实际路径),并查看是否获得了张量。

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

预测

现在将使用预训练的 DenseNet 121 模型来预测图像类别。我们将使用 torchvision 库中的一个模型,加载模型并进行推理。虽然在本例中我们将使用预训练模型,但您可以将此方法用于您自己的模型。有关加载模型的更多信息,请参阅本 教程

from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

张量 y_hat 将包含预测的类 ID 的索引。但是,我们需要一个人类可读的类名。为此,我们需要一个类 ID 到名称的映射。将 此文件 下载为 imagenet_class_index.json 并记住您将其保存到的位置(或者,如果您按照本教程中的精确步骤操作,请将其保存到 tutorials/_static)。此文件包含 ImageNet 类 ID 到 ImageNet 类名称的映射。我们将加载此 JSON 文件并获取预测索引的类名。

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用 imagenet_class_index 字典之前,我们将先将张量值转换为字符串值,因为 imagenet_class_index 字典中的键是字符串。我们将测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

您应该会收到如下响应:

['n02124075', 'Egyptian_cat']

数组中的第一个项目是 ImageNet 类 ID,第二个项目是人类可读的名称。

将模型集成到我们的 API 服务器中

在最后一部分中,我们将模型添加到 Flask API 服务器中。由于我们的 API 服务器应该接收图像文件,因此我们将更新我们的 predict 方法以从请求中读取文件:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()
FLASK_ENV=development FLASK_APP=app.py flask run

库发送 POST 请求到我们的应用程序。

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印 resp.json() 现在将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

我们编写的服务器非常简单,可能无法满足生产应用程序的所有需求。因此,以下是一些可以改进它的方法:

  • 端点 /predict 假设请求中始终存在图像文件。这对于所有请求可能并不成立。我们的用户可能会使用不同的参数发送图像,或者根本不发送图像。

  • 用户也可能发送非图像类型的文件。由于我们没有处理错误,这将导致服务器崩溃。添加一个显式的错误处理路径,该路径将抛出异常,这将使我们能够更好地处理错误的输入。

  • 尽管模型可以识别大量类别的图像,但它可能无法识别所有图像。增强实现以处理模型无法识别图像中任何内容的情况。

  • 我们以开发模式运行 Flask 服务器,这并不适合在生产环境中部署。您可以查看 本教程,了解如何在生产环境中部署 Flask 服务器。

  • 您还可以通过创建一个带有表单的页面来添加 UI,该表单接收图像并显示预测。查看类似项目的 演示 及其 源代码

  • 在本教程中,我们仅展示了如何构建一个服务,该服务可以一次返回单个图像的预测。我们可以修改我们的服务以能够一次返回多个图像的预测。此外,service-streamer 库会自动将请求排队到您的服务并将其采样到可以馈送到模型的小批量中。您可以查看 本教程

  • 最后,我们鼓励您查看我们关于部署 PyTorch 模型的其他教程,这些教程链接在页面顶部。

脚本的总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源