快捷方式

使用 Flask 部署

在本食谱中,您将学习

  • 如何将训练好的 PyTorch 模型封装在 Flask 容器中,以便通过 Web API 暴露它

  • 如何将传入的 Web 请求转换为 PyTorch 张量,以供您的模型使用

  • 如何将模型的输出打包以用于 HTTP 响应

要求

您将需要一个安装了以下软件包(及其依赖项)的 Python 3 环境

  • PyTorch 1.5

  • TorchVision 0.6.0

  • Flask 1.1

可选地,要获取一些支持文件,您需要 git。

有关安装 PyTorch 和 TorchVision 的说明,请访问 pytorch.org。有关安装 Flask 的说明,请访问 Flask 网站

什么是 Flask?

Flask 是用 Python 编写的轻量级 Web 服务器。它提供了一种便捷的方式,让您可以快速设置一个 Web API,用于从您训练好的 PyTorch 模型进行预测,无论是直接使用还是作为更大系统中的 Web 服务。

设置和支持文件

我们将创建一个 Web 服务,它接收图像并将它们映射到 ImageNet 数据集中的 1000 个类别之一。为此,您将需要一个图像文件进行测试。可选地,您还可以获取一个文件,它将模型输出的类别索引映射到人类可读的类别名称。

选项 1:快速获取这两个文件

您可以通过签出 TorchServe 存储库并将它们复制到您的工作文件夹来快速获取这两个支持文件。(注意:本教程与 TorchServe 没有依赖关系 - 这只是一种快速获取文件的方法。) 从 shell 提示符发出以下命令

git clone https://github.com/pytorch/serve
cp serve/examples/image_classifier/kitten.jpg .
cp serve/examples/image_classifier/index_to_name.json .

您就得到了它们!

选项 2:自备图像

以下 Flask 服务中的 index_to_name.json 文件是可选的。您可以使用自己的图像测试您的服务 - 只要确保它是 3 色 JPEG。

构建您的 Flask 服务

Flask 服务的完整 Python 脚本显示在本食谱的末尾;您可以将其复制粘贴到您自己的 app.py 文件中。下面我们将查看各个部分以使它们的函数清晰。

导入

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

为了

  • 我们将使用来自 torchvision.models 的预训练 DenseNet 模型

  • torchvision.transforms 包含用于操作您的图像数据的工具

  • Pillow (PIL) 是我们最初用来加载图像文件的

  • 当然,我们还需要来自 flask 的类

预处理

def transform_image(infile):
    input_transforms = [transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)
    timg = my_transforms(image)
    timg.unsqueeze_(0)
    return timg

Web 请求给了我们一个图像文件,但我们的模型期望一个形状为 (N, 3, 224, 224) 的 PyTorch 张量,其中 N 是输入批次中的项目数。(我们只会有一个批次大小。)我们做的第一件事是组成一组 TorchVision 变换,这些变换会调整图像的大小和裁剪图像,将其转换为张量,然后标准化张量中的值。(有关此标准化的更多信息,请参阅 torchvision.models_ 的文档。)

之后,我们打开文件并应用变换。变换返回一个形状为 (3, 224, 224) 的张量 - 224x224 图像的 3 个颜色通道。因为我们需要将这单个图像变成一个批次,所以我们使用 unsqueeze_(0) 调用通过添加一个新的第一个维度来就地修改张量。张量包含相同的数据,但现在形状为 (1, 3, 224, 224)。

一般来说,即使您不是在处理图像数据,您也需要将来自 HTTP 请求的输入转换为 PyTorch 可以使用的张量。

推理

def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)
    _, y_hat = outputs.max(1)
    prediction = y_hat.item()
    return prediction

推理本身是最简单的部分:当我们将输入张量传递给模型时,我们得到一个张量值,这些值代表模型估计的图像属于特定类的可能性。 max() 调用查找具有最大似然值的类,并返回该值和 ImageNet 类索引。最后,我们使用 item() 调用从包含它的张量中提取该类索引,并返回它。

后处理

def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name

render_prediction() 方法将预测的类索引映射到人类可读的类标签。在从您的模型中获得预测之后,通常执行后处理以使预测准备好用于人类使用或其他软件。

运行完整的 Flask 应用程序

将以下内容粘贴到名为 app.py 的文件中

import io
import json
import os

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
model = models.densenet121(pretrained=True)               # Trained on 1000 classes from ImageNet
model.eval()                                              # Turns off autograd



img_class_map = None
mapping_file_path = 'index_to_name.json'                  # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
    with open (mapping_file_path) as f:
        img_class_map = json.load(f)



# Transform input into the form our model expects
def transform_image(infile):
    input_transforms = [transforms.Resize(255),           # We use multiple TorchVision transforms to ready the image
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],       # Standard normalization for ImageNet model input
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)                            # Open the image file
    timg = my_transforms(image)                           # Transform PIL image to appropriately-shaped PyTorch tensor
    timg.unsqueeze_(0)                                    # PyTorch models expect batched input; create a batch of 1
    return timg


# Get a prediction
def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)                 # Get likelihoods for all ImageNet classes
    _, y_hat = outputs.max(1)                             # Extract the most likely class
    prediction = y_hat.item()                             # Extract the int value from the PyTorch tensor
    return prediction

# Make the prediction human-readable
def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name


@app.route('/', methods=['GET'])
def root():
    return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        if file is not None:
            input_tensor = transform_image(file)
            prediction_idx = get_prediction(input_tensor)
            class_id, class_name = render_prediction(prediction_idx)
            return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

要从您的 shell 提示符启动服务器,请发出以下命令

FLASK_APP=app.py flask run

默认情况下,您的 Flask 服务器在端口 5000 上监听。服务器运行后,打开另一个终端窗口,测试您的新推理服务器

curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "[email protected]"

如果一切设置正确,您应该收到类似于以下内容的响应

{"class_id":285,"class_name":"Egyptian_cat"}

重要资源

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源