
prepare_qat_fx

class torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=None, backend_config=None)[source]

Prepare a model for quantization aware training

Parameters

  • model – a torch.nn.Module model

  • qconfig_mapping – see prepare_fx()

  • example_inputs – see prepare_fx()

  • prepare_custom_config – see prepare_fx()

  • backend_config – see prepare_fx()

Returns

A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for quantization aware training

Return type

GraphModule
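
As background for the qconfig_mapping parameter (documented fully under prepare_fx()), here is a minimal sketch of building a customized QConfigMapping; the module name "sub.linear" is a hypothetical example, not part of this API:

import torch
from torch.ao.quantization import QConfigMapping, get_default_qat_qconfig

# start from a global QAT qconfig, then override per operator type and per
# module name (the setters return self, so the calls can be chained)
qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qat_qconfig("fbgemm"))
    .set_object_type(torch.nn.Linear, get_default_qat_qconfig("fbgemm"))
    .set_module_name("sub.linear", None)  # None skips quantization for this module
)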

Example

import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx

class Submodule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
    def forward(self, x):
        x = self.linear(x)
        return x

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.sub = Submodule()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub(x) + x
        return x

# initialize a floating point model
float_model = M().train()
# (optional, but preferred) load the weights from a pretrained model
# float_model.load_weights(...)

# define the training loop for quantization aware training
def train_loop(model, train_data):
    model.train()
    for image, target in train_data:
        ...

# qconfig is the configuration for how we insert observers for a particular
# operator
# qconfig = get_default_qat_qconfig("fbgemm")
# Example of customizing qconfig:
# qconfig = torch.ao.quantization.QConfig(
#    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
#    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
# `activation` and `weight` are constructors of observer modules

# qconfig_mapping is a collection of quantization configurations; users can
# set the qconfig for each operator (torch op calls, functional calls, module
# calls) in the model through qconfig_mapping
# the following call will get the qconfig_mapping that works best for models
# that target the "fbgemm" backend
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

# We can customize qconfig_mapping in different ways, please take a look at
# the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
# to configure this

# example_inputs is a tuple of inputs that is used to infer the types of the
# outputs in the model
# currently it is not used, but please make sure model(*example_inputs) runs
example_inputs = (torch.randn(1, 3, 224, 224),)

# TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
# e.g. backend_config = get_default_backend_config("fbgemm")
# `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
# backend_config. If the configuration for an operator in qconfig_mapping is
# supported in the backend_config (meaning it is supported by the target
# hardware), fake_quantize modules are inserted according to the
# qconfig_mapping; otherwise the configuration in qconfig_mapping is ignored.
# See :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
# how qconfig_mapping interacts with backend_config
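# As an illustrative sketch (an assumption, not required for this example),
# an explicit backend_config could be fetched and passed in, e.g.:
# from torch.ao.quantization.backend_config import get_native_backend_config
# prepared_model = prepare_qat_fx(
#     float_model, qconfig_mapping, example_inputs,
#     backend_config=get_native_backend_config())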
prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
# Run training (train_data is a placeholder for the user's dataset)
train_loop(prepared_model, train_data)
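
After training, the prepared model is typically converted to an actually quantized model with convert_fx; a minimal sketch continuing from the example above, assuming training has already run:

from torch.ao.quantization.quantize_fx import convert_fx

# switch to eval mode, then convert the fake-quantized model
# into a model with actual quantized ops
prepared_model.eval()
quantized_model = convert_fx(prepared_model)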
