注意
点击这里下载完整的示例代码
通过 Flask REST API 在 Python 中部署 PyTorch¶
创建于:2019 年 7 月 3 日 | 最后更新:2024 年 1 月 19 日 | 最后验证:2024 年 11 月 05 日
在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开用于模型推理的 REST API。 特别是,我们将部署一个预训练的 DenseNet 121 模型,用于检测图像。
提示
此处使用的所有代码均在 MIT 许可证下发布,并在 Github 上提供。
这是关于在生产环境中部署 PyTorch 模型的一系列教程中的第一个。 以这种方式使用 Flask 是启动服务您的 PyTorch 模型的最简单方法,但它不适用于具有高性能要求的用例。 对于这种情况
如果您已经熟悉 TorchScript,可以直接跳到我们的 在 C++ 中加载 TorchScript 模型 教程。
如果您首先需要复习 TorchScript,请查看我们的 TorchScript 简介 教程。
API 定义¶
我们将首先定义我们的 API 端点、请求和响应类型。 我们的 API 端点将在 /predict
,它接受带有 file
参数的 HTTP POST 请求,该参数包含图像。 响应将是包含预测的 JSON 响应
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
简单 Web 服务器¶
以下是一个简单的 Web 服务器,取自 Flask 的文档
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
我们还将更改响应类型,使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。 更新后的 app.py
文件现在将是
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
推理¶
在接下来的部分中,我们将重点编写推理代码。 这将涉及两个部分,一部分是我们准备图像,以便可以将其馈送到 DenseNet,接下来,我们将编写代码以从模型中获取实际预测。
准备图像¶
DenseNet 模型要求图像为 3 通道 RGB 图像,大小为 224 x 224。 我们还将使用所需的均值和标准差值对图像张量进行归一化。 您可以在 此处 阅读更多相关信息。
我们将使用 torchvision
库中的 transforms
并构建一个变换管道,该管道按需变换我们的图像。 您可以在 此处 阅读更多关于变换的信息。
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
上述方法以字节为单位获取图像数据,应用一系列变换并返回张量。 为了测试上述方法,以字节模式读取图像文件(首先将 ../_static/img/sample_file.jpeg 替换为您计算机上文件的实际路径),看看您是否得到一个张量
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)
预测¶
现在将使用预训练的 DenseNet 121 模型来预测图像类别。 我们将使用 torchvision
库中的一个模型,加载模型并获得推理结果。 虽然在本示例中我们将使用预训练模型,但您可以对自己的模型使用相同的方法。 有关加载模型的更多信息,请参阅此 教程。
from torchvision import models
# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
return y_hat
张量 y_hat
将包含预测的类 ID 的索引。 但是,我们需要人类可读的类名称。 为此,我们需要类 ID 到名称的映射。 下载 此文件 作为 imagenet_class_index.json
并记住您保存它的位置(或者,如果您正在遵循本教程中的确切步骤,请将其保存在 tutorials/_static 中)。 此文件包含 ImageNet 类 ID 到 ImageNet 类名称的映射。 我们将加载此 JSON 文件并获取预测索引的类名称。
import json
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
在使用 imagenet_class_index
字典之前,首先我们将张量值转换为字符串值,因为 imagenet_class_index
字典中的键是字符串。 我们将测试我们的上述方法
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))
您应该收到这样的响应
['n02124075', 'Egyptian_cat']
数组中的第一个项目是 ImageNet 类 ID,第二个项目是人类可读的名称。
- 将模型集成到我们的 API 服务器中
在最后一部分中,我们将我们的模型添加到我们的 Flask API 服务器中。 由于我们的 API 服务器应该接受图像文件,我们将更新我们的
predict
方法以从请求中读取文件from flask import request @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': # we will get the file from the request file = request.files['file'] # convert that to bytes img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name})
import io import json from torchvision import models import torchvision.transforms as transforms from PIL import Image from flask import Flask, jsonify, request app = Flask(__name__) imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json')) model = models.densenet121(weights='IMAGENET1K_V1') model.eval() def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0) def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx] @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': file = request.files['file'] img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name}) if __name__ == '__main__': app.run()
FLASK_ENV=development FLASK_APP=app.py flask run
库以向我们的应用程序发送 POST 请求
import requests resp = requests.post("https://127.0.0.1:5000/predict", files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
打印 resp.json() 现在将显示以下内容
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
我们编写的服务器非常简单,可能无法完成您的生产应用程序所需的一切。 因此,以下是一些您可以做的事情来使其更好
端点
/predict
假定请求中始终会有一个图像文件。 这可能不适用于所有请求。 我们的用户可能会发送带有不同参数的图像,或者根本不发送图像。用户也可能发送非图像类型的文件。 由于我们没有处理错误,这将破坏我们的服务器。 添加显式的错误处理路径以抛出异常将使我们能够更好地处理错误的输入
即使该模型可以识别大量图像类别,它也可能无法识别所有图像。 增强实现以处理模型无法识别图像中任何内容的情况。
我们在开发模式下运行 Flask 服务器,这不适合在生产环境中部署。 您可以查看 本教程,了解如何在生产环境中部署 Flask 服务器。
在本教程中,我们仅展示了如何构建一个一次只能返回单个图像预测的服务。 我们可以修改我们的服务,使其能够一次返回多个图像的预测。 此外,service-streamer 库会自动将请求排队到您的服务中,并将它们采样到可以馈送到您的模型中的小批量中。 您可以查看 本教程。
最后,我们鼓励您查看页面顶部链接的关于部署 PyTorch 模型的其他教程。
脚本的总运行时间: (0 分 0.000 秒)