prepare_fx

torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=None, _equalization_config=None, backend_config=None)

Prepare a model for post-training quantization.

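Parameters
  • model (torch.nn.Module) – the floating point model to prepare

  • qconfig_mapping (QConfigMapping) – QConfigMapping object that configures how the model is quantized; see QConfigMapping for details

  • example_inputs (tuple) – example inputs for the model's forward function, given as a tuple of positional arguments (keyword arguments can also be passed as positional arguments)

  • prepare_custom_config (PrepareCustomConfig) – custom configuration for the quantization tool; see PrepareCustomConfig for details

  • _equalization_config – configuration that specifies how to perform equalization on the model

  • backend_config (BackendConfig) – configuration that specifies how operators are quantized in a backend, including how the operators are observed, the supported fusion patterns, how quantize/dequantize ops are inserted, the supported dtypes, and so on; see BackendConfig for details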
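
As a minimal sketch of how these arguments fit together, the following builds a QConfigMapping that sets a global qconfig and then opts one submodule out by mapping its name to None. The `TwoLinear` module and its layer names are hypothetical and exist only for illustration; the "fbgemm" qconfig is assumed to be available in the installed PyTorch build.

```python
import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.fx import GraphModule


class TwoLinear(torch.nn.Module):
    """Hypothetical two-layer model used only for this sketch."""

    def __init__(self) -> None:
        super().__init__()
        self.first = torch.nn.Linear(5, 5)
        self.second = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.second(self.first(x))


# Use one qconfig for every operator, then override a single submodule.
# Mapping a module name to None asks prepare_fx to leave it unquantized.
qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("fbgemm"))
    .set_module_name("second", None)
)

# example_inputs is a tuple of positional arguments for forward()
example_inputs = (torch.randn(1, 5),)

prepared = prepare_fx(TwoLinear().eval(), qconfig_mapping, example_inputs)

# The prepared model is still a floating point model with observers attached,
# so it can be called like the original module.
output = prepared(*example_inputs)
```

More specific settings (module name, object type) take precedence over the global qconfig, which is why the override for `second` applies even though a global qconfig is set.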

Returns

A GraphModule with observers (configured by qconfig_mapping), ready for calibration

Return type

GraphModule
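
To see what the returned GraphModule contains, one option is to list the observer submodules that were inserted and print the traced graph. This is a sketch under assumptions: the single-layer `Sequential` model is hypothetical, and the filter relies on the inserted observers being subclasses of `ObserverBase`, which holds for the default "fbgemm" qconfig mapping.

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.fx import GraphModule

# Hypothetical minimal model: a single linear layer
model = torch.nn.Sequential(torch.nn.Linear(5, 5)).eval()
example_inputs = (torch.randn(1, 5),)

prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), example_inputs)

# Collect the names of the observer modules attached to the GraphModule
observer_names = [
    name for name, module in prepared.named_modules()
    if isinstance(module, ObserverBase)
]
print(observer_names)

# The graph shows where the observers sit relative to the original operators
print(prepared.graph)
```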

Example

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

class Submodule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
    def forward(self, x):
        x = self.linear(x)
        return x

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.sub = Submodule()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub(x) + x
        return x

# initialize a floating point model
float_model = M().eval()

# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# qconfig is the configuration for how we insert observers for a particular
# operator
# qconfig = get_default_qconfig("fbgemm")
# Example of customizing qconfig:
# qconfig = torch.ao.quantization.QConfig(
#    activation=MinMaxObserver.with_args(dtype=torch.qint8),
#    weight=MinMaxObserver.with_args(dtype=torch.qint8))
# `activation` and `weight` are constructors of observer module

# qconfig_mapping is a collection of quantization configurations, user can
# set the qconfig for each operator (torch op calls, functional calls, module calls)
# in the model through qconfig_mapping
# the following call will get the qconfig_mapping that works best for models
# that target "fbgemm" backend
qconfig_mapping = get_default_qconfig_mapping("fbgemm")

# We can customize qconfig_mapping in different ways.
# e.g. set the global qconfig, which means we will use the same qconfig for
# all operators in the model, this can be overwritten by other settings
# qconfig_mapping = QConfigMapping().set_global(qconfig)
# e.g. quantize the linear submodule with a specific qconfig
# qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
# e.g. quantize all nn.Linear modules with a specific qconfig
# qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
# for a more complete list, please see the docstring for :class:`torch.ao.quantization.QConfigMapping`

# example_inputs is a tuple of inputs, that is used to infer the type of the
# outputs in the model
# currently it's not used, but please make sure model(*example_inputs) runs
# (the model above starts with Linear(5, 5), so the last dimension must be 5)
example_inputs = (torch.randn(1, 5),)

# TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
# e.g. backend_config = get_default_backend_config("fbgemm")
# `prepare_fx` inserts observers in the model based on qconfig_mapping and
# backend_config. If the configuration for an operator in qconfig_mapping
# is supported in the backend_config (meaning it's supported by the target
# hardware), we'll insert observer modules according to the qconfig_mapping
# otherwise the configuration in qconfig_mapping will be ignored
#
# Example:
# in qconfig_mapping, the user sets the linear module to be quantized with quint8 for
# activation and qint8 for weight:
# qconfig = torch.ao.quantization.QConfig(
#     activation=MinMaxObserver.with_args(dtype=torch.quint8),
#     weight=MinMaxObserver.with_args(dtype=torch.qint8))
# Note: the current qconfig api does not support setting an output observer, but
# we may extend it to support this more fine-grained control in the
# future
#
# qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
# in the backend config, the linear module also supports this configuration:
# weighted_int8_dtype_config = DTypeConfig(
#   input_dtype=torch.quint8,
#   output_dtype=torch.quint8,
#   weight_dtype=torch.qint8,
#   bias_dtype=torch.float)

# linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
#    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
#    .add_dtype_config(weighted_int8_dtype_config) \
#    ...

# backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
# `prepare_fx` will check that the setting requested by the user in qconfig_mapping
# is supported by the backend_config and insert observers and fake quant modules
# in the model
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
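# Run calibration
# `sample_inference_data` is a placeholder for real calibration data: any iterable
# of (input, target) pairs whose inputs match the shape the model expects
sample_inference_data = [(torch.randn(1, 5), None) for _ in range(4)]
calibrate(prepared_model, sample_inference_data)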
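
After calibration, the observers hold the statistics that are later used to choose quantization parameters. The sketch below shows one way to read them back; it assumes a hypothetical single-layer model, random tensors standing in for real calibration data, and observers that expose `calculate_qparams()`, as the default "fbgemm" activation observers do.

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.quantize_fx import prepare_fx

torch.manual_seed(0)

# Hypothetical minimal model and inputs
model = torch.nn.Sequential(torch.nn.Linear(5, 5)).eval()
example_inputs = (torch.randn(1, 5),)
prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), example_inputs)

# Calibrate: run a few batches so the observers can record activation ranges.
# Random data is a stand-in for a representative calibration set.
with torch.no_grad():
    for _ in range(4):
        prepared(torch.randn(8, 5))

# Read the scale and zero point each observer derived from what it saw
scales = []
for name, module in prepared.named_modules():
    if isinstance(module, ObserverBase):
        scale, zero_point = module.calculate_qparams()
        scales.append(float(scale))
        print(name, float(scale), int(zero_point))
```

With real data, the calibration set should resemble the inputs seen at inference time, since the recorded ranges determine the quantization parameters.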
