使用 Flask 进行部署¶
创建于: 2020 年 5 月 4 日 | 最后更新于: 2021 年 9 月 15 日 | 最后验证于: 未验证
在本代码示例中,您将学习
如何将您训练好的 PyTorch 模型封装在 Flask 容器中,并通过 Web API 公开它
如何将传入的 Web 请求转换为用于模型的 PyTorch 张量
如何将模型的输出打包为 HTTP 响应
要求¶
您需要一个安装了以下软件包(及其依赖项)的 Python 3 环境
PyTorch 1.5
TorchVision 0.6.0
Flask 1.1
可选地,为了获取一些支持文件,您需要 git。
PyTorch 和 TorchVision 的安装说明可在 pytorch.org 找到。Flask 的安装说明可在 Flask 网站上找到。
什么是 Flask?¶
Flask 是一个用 Python 编写的轻量级 Web 服务器。它为您提供了一种便捷的方式,可以快速设置一个 Web API,用于从您训练好的 PyTorch 模型获取预测结果,无论是直接使用,还是作为更大系统中的 Web 服务。
设置和支持文件¶
我们将创建一个 Web 服务,该服务接收图像,并将其映射到 ImageNet 数据集的 1000 个类别之一。为此,您需要一个用于测试的图像文件。可选地,您还可以获取一个文件,该文件将模型输出的类别索引映射到人类可读的类别名称。
选项 1:快速获取两个文件¶
您可以通过检出 TorchServe 仓库并将它们复制到您的工作文件夹中来快速获取这两个支持文件。(注意:本教程不依赖于 TorchServe - 这只是一种获取文件的快捷方式。)从您的 shell 提示符中发出以下命令
git clone https://github.com/pytorch/serve
cp serve/examples/image_classifier/kitten.jpg .
cp serve/examples/image_classifier/index_to_name.json .
这样您就得到了它们!
选项 2:使用您自己的图像¶
下面的 Flask 服务中,index_to_name.json
文件是可选的。您可以使用自己的图像测试您的服务 - 只需确保它是 3 色 JPEG。
构建您的 Flask 服务¶
完整的 Flask 服务 Python 脚本显示在本代码示例的末尾;您可以将其复制粘贴到您自己的 app.py
文件中。下面我们将查看各个部分以明确其功能。
导入¶
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
按顺序
我们将使用来自
torchvision.models
的预训练 DenseNet 模型torchvision.transforms
包含用于处理图像数据的工具Pillow (
PIL
) 是我们最初用于加载图像文件的库当然,我们还需要来自
flask
的类
预处理¶
def transform_image(infile):
input_transforms = [transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])]
my_transforms = transforms.Compose(input_transforms)
image = Image.open(infile)
timg = my_transforms(image)
timg.unsqueeze_(0)
return timg
Web 请求给了我们一个图像文件,但我们的模型期望一个形状为 (N, 3, 224, 224) 的 PyTorch 张量,其中 N 是输入批量中的项目数量。(我们只需要一个批量大小为 1。)我们首先做的是组合一组 TorchVision 变换,这些变换会调整图像大小、裁剪图像、将其转换为张量,然后对张量中的值进行归一化。(有关此归一化的更多信息,请参阅 torchvision.models_
的文档。)
之后,我们打开文件并应用变换。这些变换返回一个形状为 (3, 224, 224) 的张量 - 一个 224x224 图像的 3 个颜色通道。因为我们需要将这个单张图像转换为一个批量,所以我们使用 unsqueeze_(0)
调用通过添加一个新的第一维度来原地修改张量。该张量包含相同的数据,但现在形状为 (1, 3, 224, 224)。
通常,即使您不处理图像数据,您也需要将来自 HTTP 请求的输入转换为 PyTorch 可以使用的张量。
推理¶
def get_prediction(input_tensor):
outputs = model.forward(input_tensor)
_, y_hat = outputs.max(1)
prediction = y_hat.item()
return prediction
推理本身是最简单的部分:当我们把输入张量传递给模型时,我们会得到一个值张量,这些值代表模型估计该图像属于某个特定类别的可能性。max()
调用找到具有最大似然值的类别,并返回该值以及 ImageNet 类别索引。最后,我们使用 item()
调用从包含该类别索引的张量中提取它,并返回它。
后处理¶
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if img_class_map is not None:
if stridx in img_class_map is not None:
class_name = img_class_map[stridx][1]
return prediction_idx, class_name
render_prediction()
方法将预测的类别索引映射到人类可读的类别标签。通常,在从模型获得预测后,会进行后处理,以使预测结果可供人类消费或供其他软件使用。
运行完整的 Flask 应用¶
将以下内容粘贴到名为 app.py
的文件中
import io
import json
import os
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
model = models.densenet121(pretrained=True) # Trained on 1000 classes from ImageNet
model.eval() # Turns off autograd
img_class_map = None
mapping_file_path = 'index_to_name.json' # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
with open (mapping_file_path) as f:
img_class_map = json.load(f)
# Transform input into the form our model expects
def transform_image(infile):
input_transforms = [transforms.Resize(255), # We use multiple TorchVision transforms to ready the image
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], # Standard normalization for ImageNet model input
[0.229, 0.224, 0.225])]
my_transforms = transforms.Compose(input_transforms)
image = Image.open(infile) # Open the image file
timg = my_transforms(image) # Transform PIL image to appropriately-shaped PyTorch tensor
timg.unsqueeze_(0) # PyTorch models expect batched input; create a batch of 1
return timg
# Get a prediction
def get_prediction(input_tensor):
outputs = model.forward(input_tensor) # Get likelihoods for all ImageNet classes
_, y_hat = outputs.max(1) # Extract the most likely class
prediction = y_hat.item() # Extract the int value from the PyTorch tensor
return prediction
# Make the prediction human-readable
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if img_class_map is not None:
if stridx in img_class_map is not None:
class_name = img_class_map[stridx][1]
return prediction_idx, class_name
@app.route('/', methods=['GET'])
def root():
return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
if file is not None:
input_tensor = transform_image(file)
prediction_idx = get_prediction(input_tensor)
class_id, class_name = render_prediction(prediction_idx)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
要从您的 shell 提示符启动服务器,请发出以下命令
FLASK_APP=app.py flask run
默认情况下,您的 Flask 服务器正在侦听端口 5000。服务器运行后,打开另一个终端窗口,测试您的新推理服务器
curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "file=@kitten.jpg"
如果一切设置正确,您应该会收到类似于以下内容的响应
{"class_id":285,"class_name":"Egyptian_cat"}
重要资源¶
pytorch.org 获取安装说明以及更多文档和教程