快捷方式

自定义服务

本文档的内容

自定义处理程序

通过编写一个 Python 脚本(在您使用模型归档器时与模型一起打包)来自定义 TorchServe 的行为。TorchServe 在运行时会执行此代码。

提供一个自定义脚本以

  • 初始化模型实例

  • 在将输入数据发送到模型进行推理或 Captum 解释之前对其进行预处理

  • 自定义模型如何被调用以进行推理或解释

  • 在将响应发回之前对模型的输出进行后处理

以下是适用于所有类型的自定义处理程序的内容

  • data - 来自传入请求的输入数据

  • context - 是 TorchServe context。您可以使用以下信息进行自定义模型名称、模型目录、清单、批次大小、gpu 等。

从 BaseHandler 开始!

BaseHandler 实现您需要的大部分功能。您可以从中派生一个新类,如示例和默认处理程序所示。大多数情况下,您只需要覆盖 preprocesspostprocess

使用 module 级入口点的自定义处理程序

自定义处理程序文件必须定义一个模块级函数,该函数充当执行的入口点。该函数可以具有任何名称,但必须接受以下参数并返回预测结果。

入口点函数的签名为

# Create model object
model = None

def entry_point_function_name(data, context):
    """
    Works on data and context to create model object or process inference request.
    Following sample demonstrates how model object can be initialized for jit mode.
    Similarly you can do it for eager mode models.
    :param data: Input data for prediction
    :param context: context contains model server system properties
    :return: prediction output
    """
    global model

    if not data:
        manifest = context.manifest

        properties = context.system_properties
        model_dir = properties.get("model_dir")
        device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        serialized_file = manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        model = torch.jit.load(model_pt_path)
    else:
        #infer and return result
        return model(data)

此入口点在两种情况下被调用

  1. TorchServe 被要求扩展模型以增加后端工作程序的数量(通过 PUT /models/{model_name} 请求或带有 initial-workers 选项的 POST /models 请求来完成,或者在您使用 --models 选项(torchserve --start --models {model_name=model.mar})时在 TorchServe 启动期间完成,即您提供要加载的模型),即您提供要加载的模型)。

  2. TorchServe 收到 POST /predictions/{model_name} 请求。

(1) 用于扩展或缩减模型的工作程序。(2) 用作针对模型运行推理的标准方法。(1) 也称为模型加载时间。通常,您希望在模型加载时运行模型初始化的代码。您可以在 TorchServe 管理 APITorchServe 推理 API 中找到有关这些 API 和其他 TorchServe API 的更多信息。

使用 class 级入口点的自定义处理程序

您可以通过使用具有任何名称的类来创建自定义处理程序,但它必须具有一个 initialize 方法和一个 handle 方法。

注意 - 如果您计划在同一个 Python 模块/文件中包含多个类,请确保处理程序类是列表中的第一个

入口点类和函数的签名为

class ModelHandler(object):
    """
    A custom model handler implementation.
    """

    def __init__(self):
        self._context = None
        self.initialized = False
        self.model = None
        self.device = None

    def initialize(self, context):
        """
        Invoke by torchserve for loading a model
        :param context: context contains model server system properties
        :return:
        """

        #  load the model
        self.manifest = context.manifest

        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        self.model = torch.jit.load(model_pt_path)

        self.initialized = True


    def handle(self, data, context):
        """
        Invoke by TorchServe for prediction request.
        Do pre-processing of data, prediction using model and postprocessing of prediciton output
        :param data: Input data for prediction
        :param context: Initial context contains model server system properties.
        :return: prediction output
        """
        pred_out = self.model.forward(data)
        return pred_out

高级自定义处理程序

返回自定义错误代码

要通过使用 module 级入口点的自定义处理程序将自定义错误代码返回给用户。

from ts.utils.util import PredictionException
def handle(data, context):
    # Some unexpected error - returning error code 513
    raise PredictionException("Some Prediction Error", 513)

要通过使用 class 级入口点的自定义处理程序将自定义错误代码返回给用户。

from ts.torch_handler.base_handler import BaseHandler
from ts.utils.util import PredictionException

class ModelHandler(BaseHandler):
    """
    A custom model handler implementation.
    """

    def handle(self, data, context):
        # Some unexpected error - returning error code 513
        raise PredictionException("Some Prediction Error", 513)

从头开始编写用于预测和解释请求的自定义处理程序

您通常应该从 BaseHandler 派生,并且只覆盖需要更改行为的方法!如您在示例中所见,大多数情况下您只需要覆盖 preprocesspostprocess

但是,您可以从头开始编写一个类。下面是一个示例。它基本上遵循一个典型的 Init-Pre-Infer-Post 模式来创建可维护的自定义处理程序。

# custom handler file

# model_handler.py

"""
ModelHandler defines a custom model handler.
"""

from ts.torch_handler.base_handler import BaseHandler

class ModelHandler(BaseHandler):
    """
    A custom model handler implementation.
    """

    def __init__(self):
        self._context = None
        self.initialized = False
        self.explain = False
        self.target = 0

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time
        :param context: Initial context contains model server system properties.
        :return:
        """
        self._context = context
        self.initialized = True
        #  load the model, refer 'custom handler class' above for details

    def preprocess(self, data):
        """
        Transform raw input into model input data.
        :param batch: list of raw requests, should match batch size
        :return: list of preprocessed model input data
        """
        # Take the input data and make it inference ready
        preprocessed_data = data[0].get("data")
        if preprocessed_data is None:
            preprocessed_data = data[0].get("body")

        return preprocessed_data


    def inference(self, model_input):
        """
        Internal inference methods
        :param model_input: transformed model input data
        :return: list of inference output in NDArray
        """
        # Do some inference call to engine here and return output
        model_output = self.model.forward(model_input)
        return model_output

    def postprocess(self, inference_output):
        """
        Return inference result.
        :param inference_output: list of inference output
        :return: list of predict results
        """
        # Take output from network and post-process to desired format
        postprocess_output = inference_output
        return postprocess_output

    def handle(self, data, context):
        """
        Invoke by TorchServe for prediction request.
        Do pre-processing of data, prediction using model and postprocessing of prediciton output
        :param data: Input data for prediction
        :param context: Initial context contains model server system properties.
        :return: prediction output
        """
        model_input = self.preprocess(data)
        model_output = self.inference(model_input)
        return self.postprocess(model_output)

有关更多详细信息,请参考 waveglow_handler

自定义处理程序的 Captum 解释

Torchserve 返回图像分类、文本分类和 BERT 模型的 Captum 解释。它是通过放置以下请求来实现的:POST /explanations/{model_name}

解释被写为基本处理程序的 explain_handle 方法的一部分。基本处理程序调用此 explain_handle_method。传递给 explain handle 方法的参数是预处理数据和原始数据。它调用自定义处理程序的 get_insights 函数,该函数返回 Captum 属性。用户应该编写自己的 get_insights 功能来获取解释

为了提供自定义处理程序,Captum 算法应该在处理程序的 initialize 函数中初始化

用户可以覆盖自定义处理程序中的 explain_handle 函数。用户应该为自定义处理程序定义自己的 get_insights 方法以获取 Captum 属性。

上面的 ModelHandler 类应该具有以下带有 Captum 功能的方法。


    def initialize(self, context):
        """
        Load the model and its artifacts
        """
        .....
        self.lig = LayerIntegratedGradients(
                captum_sequence_forward, self.model.bert.embeddings
            )

    def handle(self, data, context):
        """
        Invoke by TorchServe for prediction/explanation request.
        Do pre-processing of data, prediction using model and postprocessing of prediction/explanations output
        :param data: Input data for prediction/explanation
        :param context: Initial context contains model server system properties.
        :return: prediction/ explanations output
        """
        model_input = self.preprocess(data)
        if not self._is_explain():
                model_output = self.inference(model_input)
                model_output = self.postprocess(model_output)
            else :
                model_output = self.explain_handle(model_input, data)
            return model_output
    
    # Present in the base_handler, so override only when neccessary
    def explain_handle(self, data_preprocess, raw_data):
        """Captum explanations handler

        Args:
            data_preprocess (Torch Tensor): Preprocessed data to be used for captum
            raw_data (list): The unprocessed data to get target from the request

        Returns:
            dict : A dictionary response with the explanations response.
        """
        output_explain = None
        inputs = None
        target = 0

        logger.info("Calculating Explanations")
        row = raw_data[0]
        if isinstance(row, dict):
            logger.info("Getting data and target")
            inputs = row.get("data") or row.get("body")
            target = row.get("target")
            if not target:
                target = 0

        output_explain = self.get_insights(data_preprocess, inputs, target)
        return output_explain

    def get_insights(self,**kwargs):
        """
        Functionality to get the explanations.
        Called from the explain_handle method 
        """
        pass

扩展默认处理程序

TorchServe 具有以下默认处理程序。

如果需要,上述处理程序可以扩展以创建自定义处理程序。此外,您可以扩展抽象的 base_handler

要将默认处理程序导入 Python 脚本中,请使用以下导入语句。

from ts.torch_handler.<default_handler_name> import <DefaultHandlerClass>

以下是一个扩展默认 image_classifier 处理程序的自定义处理程序示例。

from ts.torch_handler.image_classifier import ImageClassifier

class CustomImageClassifier(ImageClassifier):

    def preprocess(self, data):
        """
        Overriding this method for custom preprocessing.
        :param data: raw data to be transformed
        :return: preprocessed data for model input
        """
        # custom pre-procsess code goes here
        return data

有关更多详细信息,请参考以下示例

创建包含入口点的模型存档

TorchServe 从清单文件中识别自定义服务的入口点。创建模型存档时,请使用 --handler 选项指定入口点的位置。

model-archiver 工具可用于创建 TorchServe 可以服务的模型存档。

torch-model-archiver --model-name <model-name> --version <model_version_number> --handler model_handler[:<entry_point_function_name>] [--model-file <path_to_model_architecture_file>] --serialized-file <path_to_state_dict_file> [--extra-files <comma_seperarted_additional_files>] [--export-path <output-dir> --model-path <model_dir>] [--runtime python3]

注意 -

  1. 方括号 [] 中的选项是可选的。

  2. entry_point_function_name 如果在您的 处理程序模块 中被命名为 handle 或处理程序是 python 类,则可以省略。

这将在 <output-dir> 目录中创建文件 <model-name>.mar,用于 python3 运行时。 --runtime 参数允许在运行时使用特定的 python 版本。默认情况下它使用系统中的默认 python 分发版。

示例

torch-model-archiver --model-name waveglow_synthesizer --version 1.0 --model-file waveglow_model.py --serialized-file nvidia_waveglowpyt_fp32_20190306.pth --handler waveglow_handler.py --extra-files tacotron.zip,nvidia_tacotron2pyt_fp32_20190306.pth

处理在多个 GPU 上的模型执行

TorchServe 在 vCPU 或 GPU 上扩展后端工作程序。在多个 GPU 的情况下,TorchServe 以循环方式选择 GPU 设备,并将此设备 ID 传递到上下文对象中的模型处理程序。用户应使用此 GPU ID 创建 PyTorch 设备对象,以确保并非所有工作程序都在同一 GPU 上创建。以下代码片段可在模型处理程序中使用,以创建 PyTorch 设备对象。

import torch

class ModelHandler(object):
    """
    A base Model handler implementation.
    """

    def __init__(self):
        self.device = None

    def initialize(self, context):
        properties = context.system_properties
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

安装模型特定的 python 依赖项

自定义模型/处理程序可能依赖于不同的 python 包,这些包在 TorchServe 设置中默认情况下不会安装。

以下步骤允许用户提供一个自定义 python 包列表,供 TorchServe 安装,以实现无缝模型服务。

  1. 启用模型特定的 python 包安装

  2. 使用模型存档提供一个需求文件.

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源