⚠️ 注意:有限维护
此项目已不再积极维护。现有版本仍可使用,但没有计划的更新、错误修复、新功能或安全补丁。用户应注意,漏洞可能不会被解决。
定制服务¶
文档目录¶
定制处理程序¶
通过编写一个 Python 脚本来定制 TorchServe 的行为,在使用模型归档器时,将此脚本与模型一起打包。TorchServe 在运行时会执行此代码。
提供定制脚本以用于
初始化模型实例
在将输入数据发送到模型进行推理或 Captum 解释之前进行预处理
定制模型如何被调用进行推理或解释
在发送响应之前对模型输出进行后处理
以下内容适用于所有类型的定制处理程序
data - 来自传入请求的输入数据
context - 是 TorchServe 上下文。您可以使用以下信息进行定制:model_name, model_dir, manifest, batch_size, gpu 等。
从 BaseHandler 开始!¶
BaseHandler 实现了您所需的大部分功能。您可以从中派生一个新类,如示例和默认处理程序所示。大多数情况下,您只需要重写 preprocess
或 postprocess
方法。
具有 模块
级别入口点的定制处理程序¶
定制处理程序文件必须定义一个模块级别的函数作为执行的入口点。该函数可以有任何名称,但必须接受以下参数并返回预测结果。
入口点函数的签名是
# Create model object
model = None
def entry_point_function_name(data, context):
"""
Works on data and context to create model object or process inference request.
Following sample demonstrates how model object can be initialized for jit mode.
Similarly you can do it for eager mode models.
:param data: Input data for prediction
:param context: context contains model server system properties
:return: prediction output
"""
global model
if not data:
manifest = context.manifest
properties = context.system_properties
model_dir = properties.get("model_dir")
device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
# Read model serialize/pt file
serialized_file = manifest['model']['serializedFile']
model_pt_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_pt_path):
raise RuntimeError("Missing the model.pt file")
model = torch.jit.load(model_pt_path)
else:
#infer and return result
return model(data)
此入口点在以下两种情况下被调用
当要求 TorchServe 扩展模型以增加后端 worker 数量时(这可以通过
PUT /models/{model_name}
请求或带有initial-workers
选项的POST /models
请求完成,或者在 TorchServe 启动时使用--models
选项时完成(torchserve --start --models {model_name=model.mar}
),即您提供要加载的模型时)TorchServe 收到
POST /predictions/{model_name}
请求时。
(1) 用于扩展或缩减模型的 worker 数量。(2) 用作对模型运行推理的标准方式。(1) 也称为模型加载时间。通常,您希望模型初始化代码在模型加载时运行。您可以在 TorchServe 管理 API 和 TorchServe 推理 API 中找到有关这些和其他 TorchServe API 的更多信息。
具有 类
级别入口点的定制处理程序¶
您可以通过创建任何名称的类来创建定制处理程序,但该类必须包含 initialize
和 handle
方法。
注意 - 如果您计划在同一个 Python 模块/文件中包含多个类,请确保处理程序类是列表中的第一个。
入口点类和方法的签名是
class ModelHandler(object):
"""
A custom model handler implementation.
"""
def __init__(self):
self._context = None
self.initialized = False
self.model = None
self.device = None
def initialize(self, context):
"""
Invoke by torchserve for loading a model
:param context: context contains model server system properties
:return:
"""
# load the model
self.manifest = context.manifest
properties = context.system_properties
model_dir = properties.get("model_dir")
self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
# Read model serialize/pt file
serialized_file = self.manifest['model']['serializedFile']
model_pt_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_pt_path):
raise RuntimeError("Missing the model.pt file")
self.model = torch.jit.load(model_pt_path)
self.initialized = True
def handle(self, data, context):
"""
Invoke by TorchServe for prediction request.
Do pre-processing of data, prediction using model and postprocessing of prediciton output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
pred_out = self.model.forward(data)
return pred_out
高级定制处理程序¶
返回定制错误代码¶
要通过具有 模块
级别入口点的定制处理程序向用户返回定制错误代码。
from ts.utils.util import PredictionException
def handle(data, context):
# Some unexpected error - returning error code 513
raise PredictionException("Some Prediction Error", 513)
要通过具有 类
级别入口点的定制处理程序向用户返回定制错误代码。
from ts.torch_handler.base_handler import BaseHandler
from ts.utils.util import PredictionException
class ModelHandler(BaseHandler):
"""
A custom model handler implementation.
"""
def handle(self, data, context):
# Some unexpected error - returning error code 513
raise PredictionException("Some Prediction Error", 513)
从头开始编写用于预测和解释请求的定制处理程序¶
通常应从 BaseHandler 派生,并且只重写需要更改行为的方法!正如您在示例中看到的那样,大多数情况下,您只需要重写 preprocess
或 postprocess
方法。
尽管如此,您仍然可以从头开始编写一个类。下面是一个示例。基本上,它遵循典型的初始化-预处理-推理-后处理(Init-Pre-Infer-Post)模式来创建可维护的定制处理程序。
# custom handler file
# model_handler.py
"""
ModelHandler defines a custom model handler.
"""
from ts.torch_handler.base_handler import BaseHandler
class ModelHandler(BaseHandler):
"""
A custom model handler implementation.
"""
def __init__(self):
self._context = None
self.initialized = False
self.explain = False
self.target = 0
def initialize(self, context):
"""
Initialize model. This will be called during model loading time
:param context: Initial context contains model server system properties.
:return:
"""
self._context = context
self.initialized = True
# load the model, refer 'custom handler class' above for details
def preprocess(self, data):
"""
Transform raw input into model input data.
:param batch: list of raw requests, should match batch size
:return: list of preprocessed model input data
"""
# Take the input data and make it inference ready
preprocessed_data = data[0].get("data")
if preprocessed_data is None:
preprocessed_data = data[0].get("body")
return preprocessed_data
def inference(self, model_input):
"""
Internal inference methods
:param model_input: transformed model input data
:return: list of inference output in NDArray
"""
# Do some inference call to engine here and return output
model_output = self.model.forward(model_input)
return model_output
def postprocess(self, inference_output):
"""
Return inference result.
:param inference_output: list of inference output
:return: list of predict results
"""
# Take output from network and post-process to desired format
postprocess_output = inference_output
return postprocess_output
def handle(self, data, context):
"""
Invoke by TorchServe for prediction request.
Do pre-processing of data, prediction using model and postprocessing of prediciton output
:param data: Input data for prediction
:param context: Initial context contains model server system properties.
:return: prediction output
"""
model_input = self.preprocess(data)
model_output = self.inference(model_input)
return self.postprocess(model_output)
有关更多详细信息,请参阅 waveglow_handler。
定制处理程序的 Captum 解释¶
Torchserve 返回图像分类、文本分类和 BERT 模型的 Captum 解释。这可以通过发送以下请求实现:POST /explanations/{model_name}
解释作为基础处理程序的 explain_handle 方法的一部分编写。基础处理程序调用此 explain_handle_method。传递给 explain handle 方法的参数是预处理后的数据和原始数据。它调用定制处理程序的 get_insights 函数,该函数返回 captum 归因。用户应自行编写 get_insights 功能以获取解释。
对于服务定制处理程序,Captum 算法应在处理程序的 initialize 函数中进行初始化。
用户可以在定制处理程序中重写 explain_handle 函数。用户应为其定制处理程序定义 get_insights 方法以获取 Captum 归因。
上述 ModelHandler 类应包含以下具有 Captum 功能的方法。
def initialize(self, context):
"""
Load the model and its artifacts
"""
.....
self.lig = LayerIntegratedGradients(
captum_sequence_forward, self.model.bert.embeddings
)
def handle(self, data, context):
"""
Invoke by TorchServe for prediction/explanation request.
Do pre-processing of data, prediction using model and postprocessing of prediction/explanations output
:param data: Input data for prediction/explanation
:param context: Initial context contains model server system properties.
:return: prediction/ explanations output
"""
model_input = self.preprocess(data)
if not self._is_explain():
model_output = self.inference(model_input)
model_output = self.postprocess(model_output)
else :
model_output = self.explain_handle(model_input, data)
return model_output
# Present in the base_handler, so override only when neccessary
def explain_handle(self, data_preprocess, raw_data):
"""Captum explanations handler
Args:
data_preprocess (Torch Tensor): Preprocessed data to be used for captum
raw_data (list): The unprocessed data to get target from the request
Returns:
dict : A dictionary response with the explanations response.
"""
output_explain = None
inputs = None
target = 0
logger.info("Calculating Explanations")
row = raw_data[0]
if isinstance(row, dict):
logger.info("Getting data and target")
inputs = row.get("data") or row.get("body")
target = row.get("target")
if not target:
target = 0
output_explain = self.get_insights(data_preprocess, inputs, target)
return output_explain
def get_insights(self,**kwargs):
"""
Functionality to get the explanations.
Called from the explain_handle method
"""
pass
扩展默认处理程序¶
TorchServe 包含以下默认处理程序。
如有需要,可以扩展上述处理程序以创建定制处理程序。此外,您还可以扩展抽象的 base_handler。
要在 Python 脚本中导入默认处理程序,请使用以下 import 语句。
from ts.torch_handler.<default_handler_name> import <DefaultHandlerClass>
以下是一个扩展默认 image_classifier 处理程序的定制处理程序示例。
from ts.torch_handler.image_classifier import ImageClassifier
class CustomImageClassifier(ImageClassifier):
def preprocess(self, data):
"""
Overriding this method for custom preprocessing.
:param data: raw data to be transformed
:return: preprocessed data for model input
"""
# custom pre-procsess code goes here
return data
有关更多详细信息,请参阅以下示例
创建包含入口点的模型归档¶
TorchServe 从清单文件中识别定制服务的入口点。创建模型归档时,使用 --handler
选项指定入口点的位置。
model-archiver 工具使您能够创建可由 TorchServe 服务的模型归档。
torch-model-archiver --model-name <model-name> --version <model_version_number> --handler model_handler[:<entry_point_function_name>] [--model-file <path_to_model_architecture_file>] --serialized-file <path_to_state_dict_file> [--extra-files <comma_seperarted_additional_files>] [--export-path <output-dir> --model-path <model_dir>] [--runtime python3]
注意 -
这会在目录 <output-dir>
中为 python3 运行时创建文件 <model-name>.mar
。 --runtime
参数允许在运行时使用特定的 python 版本。默认情况下,它使用系统的默认 python 分发版本。
示例
torch-model-archiver --model-name waveglow_synthesizer --version 1.0 --model-file waveglow_model.py --serialized-file nvidia_waveglowpyt_fp32_20190306.pth --handler waveglow_handler.py --extra-files tacotron.zip,nvidia_tacotron2pyt_fp32_20190306.pth
处理模型在多个 GPU 上的执行¶
TorchServe 在 vCPU 或 GPU 上扩展后端 worker。在多个 GPU 的情况下,TorchServe 以轮询方式选择 GPU 设备,并将此设备 ID 通过上下文对象传递给模型处理程序。用户应使用此 GPU ID 创建 PyTorch 设备对象,以确保所有 worker 不在同一个 GPU 上创建。以下代码片段可在模型处理程序中使用来创建 PyTorch 设备对象。
import torch
class ModelHandler(object):
"""
A base Model handler implementation.
"""
def __init__(self):
self.device = None
def initialize(self, context):
properties = context.system_properties
self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
安装模型特定的 Python 依赖项¶
定制模型/处理程序可能依赖于不同的 Python 包,这些包默认情况下不会作为 TorchServe
设置的一部分安装。
以下步骤允许用户提供 TorchServe 需要安装的定制 Python 包列表,以便无缝地服务模型。