• 教程 >
  • (原型) PyTorch BackendConfig 教程
快捷方式

(原型) PyTorch BackendConfig 教程

创建于:2023 年 1 月 3 日 | 最后更新:2023 年 1 月 18 日 | 最后验证:未验证

作者Andrew Or

BackendConfig API 使开发者能够将他们的后端与 PyTorch 量化集成。目前仅在 FX 图模式量化中受支持,但未来可能会扩展到其他量化模式。在本教程中,我们将演示如何使用此 API 自定义对特定后端的量化支持。有关 BackendConfig 背后的动机和实现细节的更多信息,请参阅此 README

假设我们是后端开发者,并且希望将我们的后端与 PyTorch 的量化 API 集成。我们的后端仅包含两个运算:量化线性和量化 conv-relu。在本节中,我们将逐步介绍如何通过 prepare_fxconvert_fx 使用自定义 BackendConfig 量化示例模型来实现此目的。

import torch
from torch.ao.quantization import (
    default_weight_observer,
    get_default_qconfig_mapping,
    MinMaxObserver,
    QConfig,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType,
)
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

1. 推导每个量化运算符的参考模式

对于量化线性,假设我们的后端期望参考模式 [dequant - fp32_linear - quant] 并将其降低为单个量化线性运算。实现此目的的方法是首先在浮点线性运算之前和之后插入量化-反量化运算,以便我们生成以下参考模型

quant1 - [dequant1 - fp32_linear - quant2] - dequant2

类似地,对于量化 conv-relu,我们希望生成以下参考模型,其中方括号中的参考模式将降低为单个量化 conv-relu 运算

quant1 - [dequant1 - fp32_conv_relu - quant2] - dequant2

2. 使用后端约束设置 DTypeConfig

在上面的参考模式中,DTypeConfig 中指定的输入 dtype 将作为 dtype 参数传递给 quant1,而输出 dtype 将作为 dtype 参数传递给 quant2。如果输出 dtype 是 fp32,如动态量化的情况,则不会插入输出量化-反量化对。此示例还展示了如何在特定 dtype 上指定对量化和比例范围的限制。

quint8_with_constraints = DTypeWithConstraints(
    dtype=torch.quint8,
    quant_min_lower_bound=0,
    quant_max_upper_bound=255,
    scale_min_lower_bound=2 ** -12,
)

# Specify the dtypes passed to the quantized ops in the reference model spec
weighted_int8_dtype_config = DTypeConfig(
    input_dtype=quint8_with_constraints,
    output_dtype=quint8_with_constraints,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float)

3. 设置 conv-relu 的融合

请注意,原始用户模型包含单独的 conv 和 relu 运算,因此我们需要首先将 conv 和 relu 运算融合为单个 conv-relu 运算 (fp32_conv_relu),然后像量化线性运算一样量化此运算。我们可以通过定义一个接受 3 个参数的函数来设置融合,其中第一个参数指示这是否用于 QAT,其余参数引用融合模式的各个项。

def fuse_conv2d_relu(is_qat, conv, relu):
    """Return a fused ConvReLU2d from individual conv and relu modules."""
    return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)

4. 定义 BackendConfig

现在我们拥有了所有必要的组件,因此我们继续定义我们的 BackendConfig。在这里,我们为线性运算的输入和输出使用不同的观察者(将被重命名),因此传递给两个量化运算(quant1 和 quant2)的量化参数将有所不同。这对于像线性运算和 conv 这样的加权运算通常是这种情况。

对于 conv-relu 运算,观察类型相同。但是,我们需要两个 BackendPatternConfig 来支持此运算,一个用于融合,另一个用于量化。对于 conv-relu 和线性运算,我们都使用上面定义的 DTypeConfig。

linear_config = BackendPatternConfig() \
    .set_pattern(torch.nn.Linear) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_root_module(torch.nn.Linear) \
    .set_qat_module(torch.nn.qat.Linear) \
    .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

# For fusing Conv2d + ReLU into ConvReLU2d
# No need to set observation type and dtype config here, since we are not
# inserting quant-dequant ops in this step yet
conv_relu_config = BackendPatternConfig() \
    .set_pattern((torch.nn.Conv2d, torch.nn.ReLU)) \
    .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
    .set_fuser_method(fuse_conv2d_relu)

# For quantizing ConvReLU2d
fused_conv_relu_config = BackendPatternConfig() \
    .set_pattern(torch.ao.nn.intrinsic.ConvReLU2d) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(weighted_int8_dtype_config) \
    .set_root_module(torch.nn.Conv2d) \
    .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
    .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)

backend_config = BackendConfig("my_backend") \
    .set_backend_pattern_config(linear_config) \
    .set_backend_pattern_config(conv_relu_config) \
    .set_backend_pattern_config(fused_conv_relu_config)

5. 设置满足后端约束的 QConfigMapping

为了使用上面定义的运算,用户必须定义一个 QConfig,该 QConfig 满足 DTypeConfig 中指定的约束。有关更多详细信息,请参阅 DTypeConfig 的文档。然后,我们将此 QConfig 用于我们要量化的模式中使用的所有模块。

# Note: Here we use a quant_max of 127, but this could be up to 255 (see `quint8_with_constraints`)
activation_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127, eps=2 ** -12)
qconfig = QConfig(activation=activation_observer, weight=default_weight_observer)

# Note: All individual items of a fused pattern, e.g. Conv2d and ReLU in
# (Conv2d, ReLU), must have the same QConfig
qconfig_mapping = QConfigMapping() \
    .set_object_type(torch.nn.Linear, qconfig) \
    .set_object_type(torch.nn.Conv2d, qconfig) \
    .set_object_type(torch.nn.BatchNorm2d, qconfig) \
    .set_object_type(torch.nn.ReLU, qconfig)

6. 通过 prepare 和 convert 量化模型

最后,我们通过将我们定义的 BackendConfig 传递到 prepare 和 convert 来量化模型。这将生成一个量化的线性模块和一个融合的量化 conv-relu 模块。

class MyModel(torch.nn.Module):
    def __init__(self, use_bn: bool):
        super().__init__()
        self.linear = torch.nn.Linear(10, 3)
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.use_bn = use_bn

    def forward(self, x):
        x = self.linear(x)
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        x = self.relu(x)
        x = self.sigmoid(x)
        return x

example_inputs = (torch.rand(1, 3, 10, 10, dtype=torch.float),)
model = MyModel(use_bn=False)
prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
prepared(*example_inputs)  # calibrate
converted = convert_fx(prepared, backend_config=backend_config)
>>> print(converted)

GraphModule(
  (linear): QuantizedLinear(in_features=10, out_features=3, scale=0.012136868201196194, zero_point=67, qscheme=torch.per_tensor_affine)
  (conv): QuantizedConvReLU2d(3, 3, kernel_size=(3, 3), stride=(1, 1), scale=0.0029353597201406956, zero_point=0)
  (sigmoid): Sigmoid()
)

def forward(self, x):
    linear_input_scale_0 = self.linear_input_scale_0
    linear_input_zero_point_0 = self.linear_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8);  x = linear_input_scale_0 = linear_input_zero_point_0 = None
    linear = self.linear(quantize_per_tensor);  quantize_per_tensor = None
    conv = self.conv(linear);  linear = None
    dequantize_2 = conv.dequantize();  conv = None
    sigmoid = self.sigmoid(dequantize_2);  dequantize_2 = None
    return sigmoid

(7. 实验错误的 BackendConfig 设置)

作为一项实验,在这里我们将模型修改为使用 conv-bn-relu 而不是 conv-relu,但使用相同的 BackendConfig,后者不知道如何量化 conv-bn-relu。结果,仅量化了线性,而 conv-bn-relu 既未融合也未量化。

>>> print(converted)

GraphModule(
  (linear): QuantizedLinear(in_features=10, out_features=3, scale=0.015307803638279438, zero_point=95, qscheme=torch.per_tensor_affine)
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)

def forward(self, x):
    linear_input_scale_0 = self.linear_input_scale_0
    linear_input_zero_point_0 = self.linear_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, linear_input_scale_0, linear_input_zero_point_0, torch.quint8);  x = linear_input_scale_0 = linear_input_zero_point_0 = None
    linear = self.linear(quantize_per_tensor);  quantize_per_tensor = None
    dequantize_1 = linear.dequantize();  linear = None
    conv = self.conv(dequantize_1);  dequantize_1 = None
    bn = self.bn(conv);  conv = None
    relu = self.relu(bn);  bn = None
    sigmoid = self.sigmoid(relu);  relu = None
    return sigmoid

作为另一项实验,在这里我们使用不满足后端中指定的 dtype 约束的默认 QConfigMapping。结果,由于 QConfig 被简单地忽略,因此没有任何东西被量化。

>>> print(converted)

GraphModule(
  (linear): Linear(in_features=10, out_features=3, bias=True)
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)

def forward(self, x):
    linear = self.linear(x);  x = None
    conv = self.conv(linear);  linear = None
    bn = self.bn(conv);  conv = None
    relu = self.relu(bn);  bn = None
    sigmoid = self.sigmoid(relu);  relu = None
    return sigmoid

内置 BackendConfig

PyTorch 量化在 torch.ao.quantization.backend_config 命名空间下支持一些内置的本机 BackendConfig

还有其他正在开发的 BackendConfig(例如,用于 TensorRT 和 x86),但这些目前仍大多处于实验阶段。如果用户希望将新的自定义后端与 PyTorch 的量化 API 集成,他们可以使用与定义本机支持的 BackendConfig 相同的 API 集定义自己的 BackendConfig,如上面的示例所示。


评价本教程

© Copyright 2024, PyTorch.

使用 Sphinx 构建,主题由 theme 提供,由 Read the Docs 提供支持。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源