prepare_fx
- class torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=None, _equalization_config=None, backend_config=None)
Prepare a model for post training quantization.
- Parameters
model – the torch.nn.Module model to prepare
qconfig_mapping – QConfigMapping object that configures how the model is quantized; see QConfigMapping for more details (a usage sketch follows this list)
example_inputs – example inputs for the model's forward function, as a tuple of positional arguments (keyword arguments can also be passed as positional arguments)
prepare_custom_config – customization configuration for the quantization tool; see PrepareCustomConfig for more details
_equalization_config – config that specifies how to perform equalization on the model
backend_config – config that specifies how operators are quantized in a backend; this includes how the operators are observed, supported fusion patterns, how quantize/dequantize ops are inserted, supported dtypes, etc. See BackendConfig for more details
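To make these arguments concrete, here is a minimal sketch that combines a custom QConfigMapping with a PrepareCustomConfig. `MyModel`, `NonTraceable`, and the module name `"skip"` are illustrative stand-ins, not part of the API:

```python
import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.quantize_fx import prepare_fx

class NonTraceable(torch.nn.Module):
    # stand-in for code that FX symbolic tracing cannot handle
    # (data-dependent control flow)
    def forward(self, x):
        return x * 2 if x.sum() > 0 else x

class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.skip = NonTraceable()

    def forward(self, x):
        return self.skip(self.linear(x))

float_model = MyModel().eval()

# quantize every nn.Linear with the default fbgemm qconfig
qconfig_mapping = QConfigMapping().set_object_type(
    torch.nn.Linear, get_default_qconfig("fbgemm"))

# tell prepare_fx not to trace into the data-dependent submodule
prepare_custom_config = PrepareCustomConfig().set_non_traceable_module_names(["skip"])

example_inputs = (torch.randn(1, 4),)
prepared = prepare_fx(
    float_model,
    qconfig_mapping,
    example_inputs,
    prepare_custom_config=prepare_custom_config,
)
```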
- Returns
A GraphModule with observers (configured by qconfig_mapping), ready for calibration
- Return type
GraphModule
Example:
```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

class Submodule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        x = self.linear(x)
        return x

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.sub = Submodule()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub(x) + x
        return x

# initialize a floating point model
float_model = M().eval()

# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# qconfig is the configuration for how we insert observers for a particular
# operator
# qconfig = get_default_qconfig("fbgemm")
# Example of customizing qconfig:
# qconfig = torch.ao.quantization.QConfig(
#     activation=MinMaxObserver.with_args(dtype=torch.qint8),
#     weight=MinMaxObserver.with_args(dtype=torch.qint8))
# `activation` and `weight` are constructors of observer modules

# qconfig_mapping is a collection of quantization configurations; the user can
# set the qconfig for each operator (torch op calls, functional calls, module
# calls) in the model through qconfig_mapping
# the following call will get the qconfig_mapping that works best for models
# that target the "fbgemm" backend
qconfig_mapping = get_default_qconfig_mapping("fbgemm")

# We can customize qconfig_mapping in different ways.
# e.g. set the global qconfig, which means we will use the same qconfig for
# all operators in the model; this can be overwritten by other settings
# qconfig_mapping = QConfigMapping().set_global(qconfig)
# e.g. quantize the linear submodule with a specific qconfig
# qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
# e.g. quantize all nn.Linear modules with a specific qconfig
# qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
# for a more complete list, please see the docstring for the
# torch.ao.quantization.QConfigMapping argument

# example_inputs is a tuple of inputs that is used to infer the types of the
# outputs in the model
# currently it's not used, but please make sure model(*example_inputs) runs
example_inputs = (torch.randn(1, 5),)

# TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
# e.g. backend_config = get_default_backend_config("fbgemm")

# `prepare_fx` inserts observers in the model based on qconfig_mapping and
# backend_config. If the configuration for an operator in qconfig_mapping
# is supported in the backend_config (meaning it's supported by the target
# hardware), we'll insert observer modules according to the qconfig_mapping;
# otherwise the configuration in qconfig_mapping will be ignored
#
# Example:
# in qconfig_mapping, the user sets the linear module to be quantized with
# quint8 for activation and qint8 for weight:
# qconfig = torch.ao.quantization.QConfig(
#     activation=MinMaxObserver.with_args(dtype=torch.quint8),
#     weight=MinMaxObserver.with_args(dtype=torch.qint8))
# Note: the current qconfig api does not support setting output observers, but
# we may extend this to support more fine grained control in the future
#
# qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
# in the backend config, the linear module also supports this configuration:
# weighted_int8_dtype_config = DTypeConfig(
#     input_dtype=torch.quint8,
#     output_dtype=torch.quint8,
#     weight_dtype=torch.qint8,
#     bias_dtype=torch.float)
# linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
#     .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
#     .add_dtype_config(weighted_int8_dtype_config) \
#     ...
# backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
# `prepare_fx` will check that the settings requested by the user in
# qconfig_mapping are supported by the backend_config, and insert observers
# and fake quant modules in the model
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
# Run calibration
calibrate(prepared_model, sample_inference_data)
```
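After calibration, the observed model is typically lowered to a quantized model with convert_fx. A short sketch of that final step, continuing from `prepared_model` and `example_inputs` in the example above:

```python
from torch.ao.quantization.quantize_fx import convert_fx

# convert the calibrated (observed) model to a quantized model
quantized_model = convert_fx(prepared_model)
# quantized_model is a GraphModule with quantized ops, ready for inference
out = quantized_model(*example_inputs)
```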