量化¶
警告
量化功能目前处于 Beta 阶段,可能会发生更改。
量化简介¶
量化是指以低于浮点精度的位宽执行计算和存储张量的技术。量化模型在精度降低的张量上(而不是全精度(浮点)值)执行部分或全部操作。这允许更紧凑的模型表示,并在许多硬件平台上使用高性能向量化操作。与典型的 FP32 模型相比,PyTorch 支持 INT8 量化,从而可以将模型大小减少 4 倍,并将内存带宽需求减少 4 倍。与 FP32 计算相比,硬件对 INT8 计算的支持通常快 2 到 4 倍。量化主要是一种加速推理的技术,量化算子仅支持前向传播。
PyTorch 支持多种量化深度学习模型的方法。在大多数情况下,模型在 FP32 中训练,然后将模型转换为 INT8。此外,PyTorch 还支持量化感知训练,该训练使用伪量化模块在前向和后向传播中模拟量化误差。请注意,整个计算均以浮点执行。在量化感知训练结束时,PyTorch 提供转换函数以将训练后的模型转换为较低精度。
在较低级别,PyTorch 提供了一种表示量化张量并使用它们执行操作的方法。它们可用于直接构建模型,以较低精度执行全部或部分计算。PyTorch 提供了更高级别的 API,其中包含将 FP32 模型转换为较低精度的典型工作流程,同时最大限度地减少精度损失。
量化 API 摘要¶
PyTorch 提供了三种不同的量化模式:Eager 模式量化、FX 图模式量化(维护模式)和 PyTorch 2 导出量化。
Eager 模式量化是一项 Beta 功能。用户需要手动执行融合,并指定量化和反量化发生的位置,并且它仅支持模块,不支持函数式。
FX 图模式量化是 PyTorch 中的自动化量化工作流程,目前是一项原型功能,自我们推出 PyTorch 2 导出量化以来,它处于维护模式。它通过添加对函数式的支持并自动化量化过程,改进了 Eager 模式量化,尽管人们可能需要重构模型以使模型与 FX 图模式量化兼容(使用 torch.fx
进行符号追踪)。请注意,FX 图模式量化预计无法在任意模型上工作,因为模型可能无法进行符号追踪,我们将将其集成到 torchvision 等领域库中,用户将能够使用 FX 图模式量化来量化类似于受支持领域库中模型的模型。对于任意模型,我们将提供一般指南,但要使其真正发挥作用,用户可能需要熟悉 torch.fx
,尤其是在如何使模型可进行符号追踪方面。
PyTorch 2 导出量化是新的全图模式量化工作流程,在 PyTorch 2.1 中作为原型功能发布。借助 PyTorch 2,我们正在转向更好的全程序捕获解决方案 (torch.export),因为它与 torch.fx.symbolic_trace(14K 模型上的 72.7%)相比,可以捕获更高比例的模型(14K 模型上的 88.8%),而 torch.fx.symbolic_trace 是 FX 图模式量化使用的程序捕获解决方案。torch.export 在某些 python 构造方面仍然存在限制,并且需要用户参与以支持导出模型中的动态性,但总的来说,它是对以前程序捕获解决方案的改进。PyTorch 2 导出量化是为 torch.export 捕获的模型而构建的,同时考虑了建模用户和后端开发人员的灵活性和生产力。主要功能包括:(1). 可编程 API,用于配置模型的量化方式,可以扩展到更多用例 (2). 简化的建模用户和后端开发人员 UX,因为他们只需要与单个对象(Quantizer)交互,即可表达用户关于如何量化模型以及后端支持的意图。(3). 可选的参考量化模型表示,可以使用整数运算来表示量化计算,从而更接近硬件中发生的实际量化计算。
鼓励量化新用户首先试用 PyTorch 2 导出量化,如果效果不佳,用户可以尝试 Eager 模式量化。
下表比较了 Eager 模式量化、FX 图模式量化和 PyTorch 2 导出量化之间的差异
Eager 模式量化 |
FX 图模式量化 |
PyTorch 2 导出量化 |
|
发布状态 |
beta |
prototype(维护模式) |
prototype |
算子融合 |
手动 |
自动 |
自动 |
量化/反量化放置 |
手动 |
自动 |
自动 |
量化模块 |
支持 |
支持 |
支持 |
量化函数式/Torch 算子 |
手动 |
自动 |
支持 |
自定义支持 |
有限支持 |
完全支持 |
完全支持 |
量化模式支持 |
训练后量化:静态、动态、仅权重 量化感知训练:静态 |
训练后量化:静态、动态、仅权重 量化感知训练:静态 |
由后端特定的量化器定义 |
输入/输出模型类型 |
|
|
|
支持三种类型的量化
动态量化(权重被量化,激活以浮点读取/存储,并量化用于计算)
静态量化(权重被量化,激活被量化,训练后需要校准)
静态量化感知训练(权重被量化,激活被量化,训练期间对量化数值进行建模)
请参阅我们的 PyTorch 量化简介 博客文章,以更全面地了解这些量化类型之间的权衡。
算子覆盖率因动态量化和静态量化而异,并在下表中捕获。
静态量化 |
动态量化 |
|
nn.Linear
nn.Conv1d/2d/3d
|
是
是
|
是
否
|
nn.LSTM
nn.GRU
|
是(通过
自定义模块)
否
|
是
是
|
nn.RNNCell
nn.GRUCell
nn.LSTMCell
|
否
否
否
|
是
是
是
|
nn.EmbeddingBag |
是(激活为 fp32) |
是 |
nn.Embedding |
是 |
是 |
nn.MultiheadAttention |
是(通过自定义模块) |
不支持 |
激活 |
广泛支持 |
未更改,计算保持在 fp32 中 |
Eager 模式量化¶
有关量化流程的一般介绍,包括不同类型的量化,请查看 通用量化流程。
训练后动态量化¶
这是最容易应用的量化形式,其中权重提前量化,但激活在推理期间动态量化。这适用于模型执行时间主要由从内存加载权重而不是计算矩阵乘法决定的情况。对于具有小批量大小的 LSTM 和 Transformer 类型模型来说,情况确实如此。
图表
# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
/
linear_weight_fp32
# dynamically quantized model
# linear and LSTM weights are in int8
previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32
/
linear_weight_int8
PTDQ API 示例
import torch
# define a floating point model
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(4, 4)
def forward(self, x):
x = self.fc(x)
return x
# create a model instance
model_fp32 = M()
# create a quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
model_fp32, # the original model
{torch.nn.Linear}, # a set of layers to dynamically quantize
dtype=torch.qint8) # the target dtype for quantized weights
# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)
要了解有关动态量化的更多信息,请参阅我们的 动态量化教程。
训练后静态量化¶
训练后静态量化(PTQ 静态)量化模型的权重和激活。它尽可能将激活融合到前面的层中。它需要使用代表性数据集进行校准,以确定激活的最佳量化参数。当内存带宽和计算节省都很重要时,通常使用训练后静态量化,CNN 是典型的用例。
在应用训练后静态量化之前,我们可能需要修改模型。请参阅 Eager 模式静态量化的模型准备。
图表
# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
/
linear_weight_fp32
# statically quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
/
linear_weight_int8
PTSQ API 示例
import torch
# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
def __init__(self):
super().__init__()
# QuantStub converts tensors from floating point to quantized
self.quant = torch.ao.quantization.QuantStub()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.relu = torch.nn.ReLU()
# DeQuantStub converts tensors from quantized to floating point
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
# manually specify where tensors will be converted from floating
# point to quantized in the quantized model
x = self.quant(x)
x = self.conv(x)
x = self.relu(x)
# manually specify where tensors will be converted from quantized
# to floating point in the quantized model
x = self.dequant(x)
return x
# create a model instance
model_fp32 = M()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
要了解有关静态量化的更多信息,请参阅 静态量化教程。
静态量化感知训练¶
量化感知训练 (QAT) 在训练期间模拟量化的影响,从而实现比其他量化方法更高的精度。我们可以对静态、动态或仅权重量化执行 QAT。在训练期间,所有计算均以浮点完成,其中 fake_quant 模块通过钳位和舍入来模拟量化的影响,以模拟 INT8 的效果。在模型转换后,权重和激活被量化,并且激活被融合到前面的层中(如果可能)。它通常与 CNN 一起使用,并且与静态量化相比,可产生更高的精度。
在应用训练后静态量化之前,我们可能需要修改模型。请参阅 Eager 模式静态量化的模型准备。
图表
# original model
# all tensors and computations are in floating point
previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32
/
linear_weight_fp32
# model with fake_quants for modeling quantization numerics during training
previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32
/
linear_weight_fp32 -- fq
# quantized model
# weights and activations are in int8
previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8
/
linear_weight_int8
QAT API 示例
import torch
# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
def __init__(self):
super().__init__()
# QuantStub converts tensors from floating point to quantized
self.quant = torch.ao.quantization.QuantStub()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.bn = torch.nn.BatchNorm2d(1)
self.relu = torch.nn.ReLU()
# DeQuantStub converts tensors from quantized to floating point
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.dequant(x)
return x
# create a model instance
model_fp32 = M()
# model must be set to eval for fusion to work
model_fp32.eval()
# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
[['conv', 'bn', 'relu']])
# Prepare the model for QAT. This inserts observers and fake_quants in
# the model needs to be set to train for QAT logic to work
# the model that will observe weight and activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())
# run the training loop (not shown)
training_loop(model_fp32_prepared)
# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
要了解有关量化感知训练的更多信息,请参阅 QAT 教程。
Eager 模式静态量化的模型准备¶
目前,有必要在 Eager 模式量化之前对模型定义进行一些修改。这是因为目前量化是按模块进行的。具体而言,对于所有量化技术,用户都需要
将任何需要输出重新量化(并因此具有附加参数)的操作从函数式转换为模块形式(例如,使用
torch.nn.ReLU
而不是torch.nn.functional.relu
)。通过在子模块上分配
.qconfig
属性或通过指定qconfig_mapping
,指定模型的哪些部分需要量化。例如,设置model.conv1.qconfig = None
表示model.conv
层不会被量化,而设置model.linear1.qconfig = custom_qconfig
表示model.linear1
的量化设置将使用custom_qconfig
而不是全局 qconfig。
对于量化激活的静态量化技术,用户还需要执行以下操作
指定激活量化和反量化的位置。这是使用
QuantStub
和DeQuantStub
模块完成的。使用
FloatFunctional
将张量操作(需要特殊处理才能进行量化)包装到模块中。例如,像add
和cat
这样的操作,需要特殊处理来确定输出量化参数。融合模块:将操作/模块组合成单个模块,以获得更高的精度和性能。这是使用
fuse_modules()
API 完成的,该 API 接受要融合的模块列表。我们目前支持以下融合:[Conv, Relu]、[Conv, BatchNorm]、[Conv, BatchNorm, Relu]、[Linear, Relu]
(原型 - 维护模式)FX 图模式量化¶
训练后量化中有多种量化类型(仅权重、动态和静态),配置通过 qconfig_mapping(prepare_fx 函数的参数)完成。
FXPTQ API 示例
import torch
from torch.ao.quantization import (
get_default_qconfig_mapping,
get_default_qat_qconfig_mapping,
QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy
model_fp = UserModel()
#
# post training dynamic/weight_only quantization
#
# we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs are needed to trace the model
example_inputs = (input_fp32)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# post training static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# quantization aware training for static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
请按照以下教程了解有关 FX 图模式量化的更多信息
(原型)PyTorch 2 导出量化¶
API 示例
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5, 10)
def forward(self, x):
return self.linear(x)
# initialize a floating point model
float_model = M().eval()
# define calibration function
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result should mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)
# we get a model with aten ops
# Step 2. quantization
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they
# want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# or prepare_qat_pt2e for Quantization Aware Training
m = prepare_pt2e(m, quantizer)
# run calibration
# calibrate(m, sample_inference_data)
m = convert_pt2e(m)
# Step 3. lowering
# lower to target backend
请按照以下教程开始使用 PyTorch 2 导出量化
建模用户
后端开发人员(也请查看所有建模用户文档)
量化堆栈¶
量化是将浮点模型转换为量化模型的过程。因此,在高层次上,量化堆栈可以分为两部分:1). 量化模型的构建块或抽象 2). 将浮点模型转换为量化模型的量化流程的构建块或抽象
量化模型¶
量化张量¶
为了在 PyTorch 中进行量化,我们需要能够用张量表示量化数据。量化张量允许存储量化数据(表示为 int8/uint8/int32),以及量化参数(如比例和零点)。量化张量允许许多有用的操作,从而使量化算术变得容易,此外还允许以量化格式序列化数据。
PyTorch 支持按张量和按通道的对称和非对称量化。按张量表示张量中的所有值都以相同的方式使用相同的量化参数进行量化。按通道表示对于每个维度(通常是张量的通道维度),张量中的值都使用不同的量化参数进行量化。这可以减少将张量转换为量化值时的误差,因为异常值只会影响它所在的通道,而不是整个张量。
映射是通过使用以下公式转换浮点张量来执行的
data:image/s3,"s3://crabby-images/c64c0/c64c0f1af384ac90d64099911987fd539f0b1086" alt="_images/math-quantizer-equation.png"
请注意,我们确保浮点中的零在量化后以无误差的方式表示,从而确保诸如填充之类的操作不会导致额外的量化误差。
以下是量化张量的一些关键属性
QScheme (torch.qscheme):一个枚举,指定我们量化张量的方式
torch.per_tensor_affine
torch.per_tensor_symmetric
torch.per_channel_affine
torch.per_channel_symmetric
dtype (torch.dtype):量化张量的数据类型
torch.quint8
torch.qint8
torch.qint32
torch.float16
量化参数(根据 QScheme 而异):所选量化方式的参数
torch.per_tensor_affine 将具有以下量化参数
scale(float)
zero_point(int)
torch.per_channel_affine 将具有以下量化参数
per_channel_scales(float 列表)
per_channel_zero_points(int 列表)
axis(int)
量化和反量化¶
模型的输入和输出是浮点张量,但量化模型中的激活是量化的,因此我们需要运算符在浮点张量和量化张量之间进行转换。
量化(float -> quantized)
torch.quantize_per_tensor(x, scale, zero_point, dtype)
torch.quantize_per_channel(x, scales, zero_points, axis, dtype)
torch.quantize_per_tensor_dynamic(x, dtype, reduce_range)
to(torch.float16)
反量化(quantized -> float)
quantized_tensor.dequantize() - 在 torch.float16 张量上调用 dequantize 会将张量转换回 torch.float
torch.dequantize(x)
量化算子/模块¶
量化算子是以量化张量作为输入并输出量化张量的算子。
量化模块是执行量化操作的 PyTorch 模块。它们通常为加权操作(如线性层和卷积层)定义。
量化引擎¶
当量化模型执行时,qengine (torch.backends.quantized.engine) 指定要用于执行的后端。务必确保 qengine 在量化激活和权重的取值范围方面与量化模型兼容。
量化流程¶
Observer 和 FakeQuantize¶
Observer 是 PyTorch 模块,用于
收集张量统计信息,例如通过 observer 的张量的最小值和最大值
并根据收集的张量统计信息计算量化参数
FakeQuantize 是 PyTorch 模块,用于
模拟网络中张量的量化(执行量化/反量化)
它可以根据来自 observer 的收集统计信息计算量化参数,也可以学习量化参数
QConfig¶
QConfig 是 Observer 或 FakeQuantize 模块类的命名元组,可以使用 qscheme、dtype 等进行配置。它用于配置应如何观察算子
算子/模块的量化配置
不同类型的 Observer/FakeQuantize
dtype
qscheme
quant_min/quant_max:可用于模拟较低精度的张量
目前支持激活和权重的配置
我们根据为给定算子或模块配置的 qconfig 插入输入/权重/输出 observer
通用量化流程¶
一般来说,流程如下
prepare(准备)
根据用户指定的 qconfig 插入 Observer/FakeQuantize 模块
calibrate/train(校准/训练)(取决于训练后量化或量化感知训练)
允许 Observer 收集统计信息或 FakeQuantize 模块学习量化参数
convert(转换)
将校准/训练后的模型转换为量化模型
量化模式有多种,可以按两种方式分类
在应用量化流程的位置方面,我们有
训练后量化(在训练后应用量化,量化参数根据样本校准数据计算)
量化感知训练(在训练期间模拟量化,以便可以使用训练数据与模型一起学习量化参数)
在如何量化算子方面,我们可以有
仅权重量化(仅权重是静态量化的)
动态量化(权重是静态量化的,激活是动态量化的)
静态量化(权重和激活都是静态量化的)
我们可以在同一量化流程中混合使用不同的算子量化方式。例如,我们可以进行训练后量化,其中同时包含静态量化和动态量化的算子。
量化支持矩阵¶
量化模式支持¶
量化模式 |
数据集要求 |
最适合 |
精度 |
备注 |
||
训练后量化 |
动态/仅权重量化 |
激活动态量化 (fp16, int8) 或不量化,权重静态量化 (fp16, int8, in4) |
无 |
LSTM、MLP、Embedding、Transformer |
良好 |
易于使用,当性能受限于权重导致的计算或内存时,接近静态量化 |
静态量化 |
激活和权重静态量化 (int8) |
校准数据集 |
CNN |
良好 |
提供最佳性能,可能对精度有很大影响,适用于仅支持 int8 计算的硬件 |
|
量化感知训练 |
动态量化 |
激活和权重均为伪量化 |
微调数据集 |
MLP、Embedding |
最佳 |
目前支持有限 |
静态量化 |
激活和权重均为伪量化 |
微调数据集 |
CNN、MLP、Embedding |
最佳 |
通常在静态量化导致精度不佳时使用,用于缩小精度差距 |
请参阅我们的 PyTorch 量化简介 博客文章,以更全面地了解这些量化类型之间的权衡。
量化流程支持¶
PyTorch 提供了两种量化模式:Eager 模式量化和 FX 图模式量化。
Eager 模式量化是一项 Beta 功能。用户需要手动执行融合,并指定量化和反量化发生的位置,并且它仅支持模块,不支持函数式。
FX 图模式量化是 PyTorch 中的自动化量化框架,目前是一项原型功能。它通过添加对函数式的支持并自动化量化过程,改进了 Eager 模式量化,尽管人们可能需要重构模型以使模型与 FX 图模式量化兼容(使用 torch.fx
进行符号追踪)。请注意,FX 图模式量化预计无法在任意模型上工作,因为模型可能无法进行符号追踪,我们将将其集成到 torchvision 等领域库中,用户将能够使用 FX 图模式量化来量化类似于受支持领域库中模型的模型。对于任意模型,我们将提供一般指南,但要使其真正发挥作用,用户可能需要熟悉 torch.fx
,尤其是在如何使模型可进行符号追踪方面。
鼓励量化新用户首先试用 FX 图模式量化,如果不起作用,用户可以尝试遵循 使用 FX 图模式量化 的指南,或回退到 Eager 模式量化。
下表比较了 Eager 模式量化和 FX 图模式量化之间的差异
Eager 模式量化 |
FX 图模式量化 |
|
发布状态 |
beta |
prototype |
算子融合 |
手动 |
自动 |
量化/反量化放置 |
手动 |
自动 |
量化模块 |
支持 |
支持 |
量化函数式/Torch 算子 |
手动 |
自动 |
自定义支持 |
有限支持 |
完全支持 |
量化模式支持 |
训练后量化:静态、动态、仅权重 量化感知训练:静态 |
训练后量化:静态、动态、仅权重 量化感知训练:静态 |
输入/输出模型类型 |
|
|
后端/硬件支持¶
硬件 |
内核库 |
Eager 模式量化 |
FX 图模式量化 |
量化模式支持 |
服务器 CPU |
fbgemm/onednn |
支持 |
全部支持 |
|
移动 CPU |
qnnpack/xnnpack |
|||
服务器 GPU |
TensorRT(早期原型) |
不支持,因为它需要图 |
支持 |
静态量化 |
如今,PyTorch 支持以下后端以高效运行量化算子
具有 AVX2 支持或更高版本的 x86 CPU(在没有 AVX2 的情况下,某些操作的实现效率低下),通过 x86 优化,由 fbgemm 和 onednn 提供(详细信息请参见 RFC)
ARM CPU(通常在移动/嵌入式设备中找到),通过 qnnpack
(早期原型)通过 fx2trt 支持 NVidia GPU 的 TensorRT(即将开源)
本机 CPU 后端注意事项¶
我们使用相同的本机 pytorch 量化算子公开 x86 和 qnnpack,因此我们需要额外的标志来区分它们。x86 和 qnnpack 的相应实现是根据 PyTorch 构建模式自动选择的,但用户可以选择通过将 torch.backends.quantization.engine 设置为 x86 或 qnnpack 来覆盖此设置。
在准备量化模型时,务必确保 qconfig 和用于量化计算的引擎与模型将在其上执行的后端匹配。qconfig 控制量化过程中使用的 observer 类型。qengine 控制在为线性层和卷积层函数及模块打包权重时,是使用 x86 还是 qnnpack 特定打包函数。例如
x86 的默认设置
# set the qconfig for PTQ
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default on x86 CPUs
qconfig = torch.ao.quantization.get_default_qconfig('x86')
# or, set the qconfig for QAT
qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
# set the qengine to control weight packing
torch.backends.quantized.engine = 'x86'
qnnpack 的默认设置
# set the qconfig for PTQ
qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
# or, set the qconfig for QAT
qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
# set the qengine to control weight packing
torch.backends.quantized.engine = 'qnnpack'
算子支持¶
算子覆盖率因动态量化和静态量化而异,并在下表中捕获。请注意,对于 FX 图模式量化,也支持相应的函数式。
静态量化 |
动态量化 |
|
nn.Linear
nn.Conv1d/2d/3d
|
是
是
|
是
否
|
nn.LSTM
nn.GRU
|
否
否
|
是
是
|
nn.RNNCell
nn.GRUCell
nn.LSTMCell
|
否
否
否
|
是
是
是
|
nn.EmbeddingBag |
是(激活为 fp32) |
是 |
nn.Embedding |
是 |
是 |
nn.MultiheadAttention |
不支持 |
不支持 |
激活 |
广泛支持 |
未更改,计算保持在 fp32 中 |
注意:这将很快使用来自本机 backend_config_dict 生成的一些信息进行更新。
量化自定义¶
虽然提供了基于观察到的张量数据来选择比例因子和偏差的观察者的默认实现,但开发人员可以提供他们自己的量化函数。量化可以有选择地应用于模型的不同部分,或者针对模型的不同部分进行不同的配置。
我们还为 conv1d()、conv2d()、conv3d() 和 linear() 提供 per channel 量化支持。
量化工作流程通过在模型的模块层级结构中添加(例如,添加观察者作为 .observer
子模块)或替换(例如,将 nn.Conv2d
转换为 nn.quantized.Conv2d
)子模块来工作。这意味着模型在整个过程中仍然是一个常规的基于 nn.Module
的实例,因此可以与 PyTorch API 的其余部分一起工作。
量化自定义模块 API¶
Eager 模式和 FX 图模式量化 API 都为用户提供了一个钩子,以自定义方式指定量化模块,并为观察和量化提供用户定义的逻辑。用户需要指定
源 fp32 模块(模型中已存在)的 Python 类型
被观察模块的 Python 类型(由用户提供)。此模块需要定义一个 from_float 函数,该函数定义如何从原始 fp32 模块创建被观察模块。
量化模块的 Python 类型(由用户提供)。此模块需要定义一个 from_observed 函数,该函数定义如何从被观察模块创建量化模块。
描述上述 (1)、(2)、(3) 的配置,传递给量化 API。
然后框架将执行以下操作
在 prepare 模块交换期间,它将使用 (2) 类中的 from_float 函数,将 (1) 中指定的每种类型的模块转换为 (2) 中指定的类型。
在 convert 模块交换期间,它将使用 (3) 类中的 from_observed 函数,将 (2) 中指定的每种类型的模块转换为 (3) 中指定的类型。
目前,有一个要求是 ObservedCustomModule 将具有单个 Tensor 输出,并且观察者将由框架(而不是用户)添加到该输出。观察者将存储在 activation_post_process 键下,作为自定义模块实例的属性。放宽这些限制可能会在未来进行。
自定义 API 示例
import torch
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import QConfigMapping
import torch.ao.quantization.quantize_fx
# original fp32 module to replace
class CustomModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
return self.linear(x)
# custom observed module, provided by user
class ObservedCustomModule(torch.nn.Module):
def __init__(self, linear):
super().__init__()
self.linear = linear
def forward(self, x):
return self.linear(x)
@classmethod
def from_float(cls, float_module):
assert hasattr(float_module, 'qconfig')
observed = cls(float_module.linear)
observed.qconfig = float_module.qconfig
return observed
# custom quantized module, provided by user
class StaticQuantCustomModule(torch.nn.Module):
def __init__(self, linear):
super().__init__()
self.linear = linear
def forward(self, x):
return self.linear(x)
@classmethod
def from_observed(cls, observed_module):
assert hasattr(observed_module, 'qconfig')
assert hasattr(observed_module, 'activation_post_process')
observed_module.linear.activation_post_process = \
observed_module.activation_post_process
quantized = cls(nnq.Linear.from_float(observed_module.linear))
return quantized
#
# example API call (Eager mode quantization)
#
m = torch.nn.Sequential(CustomModule()).eval()
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: StaticQuantCustomModule
}
}
m.qconfig = torch.ao.quantization.default_qconfig
mp = torch.ao.quantization.prepare(
m, prepare_custom_config_dict=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.convert(
mp, convert_custom_config_dict=convert_custom_config_dict)
#
# example API call (FX graph mode quantization)
#
m = torch.nn.Sequential(CustomModule()).eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
"static": {
CustomModule: ObservedCustomModule,
}
}
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"static": {
ObservedCustomModule: StaticQuantCustomModule,
}
}
}
mp = torch.ao.quantization.quantize_fx.prepare_fx(
m, qconfig_mapping, torch.randn(3,3), prepare_custom_config=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.quantize_fx.convert_fx(
mp, convert_custom_config=convert_custom_config_dict)
最佳实践¶
1. 如果您正在使用 x86
后端,我们需要使用 7 位而不是 8 位。确保您减小 quant\_min
、quant\_max
的范围,例如,如果 dtype
是 torch.quint8
,请确保将自定义 quant_min
设置为 0
,并将 quant_max
设置为 127
(255
/ 2
),如果 dtype
是 torch.qint8
,请确保将自定义 quant_min
设置为 -64
(-128
/ 2
),并将 quant_max
设置为 63
(127
/ 2
),如果您调用 torch.ao.quantization.get_default_qconfig(backend) 或 torch.ao.quantization.get_default_qat_qconfig(backend) 函数来获取 x86
或 qnnpack
后端的默认 qconfig
,我们已经正确设置了这一点
2. 如果选择 onednn
后端,则默认 qconfig 映射 torch.ao.quantization.get_default_qconfig_mapping('onednn')
和默认 qconfig torch.ao.quantization.get_default_qconfig('onednn')
中将使用 8 位激活。建议在支持向量神经网络指令 (VNNI) 的 CPU 上使用。否则,将激活观察者的 reduce_range
设置为 True,以在不支持 VNNI 的 CPU 上获得更好的精度。
常见问题¶
如何在 GPU 上进行量化推理?
我们目前还没有官方的 GPU 支持,但这是一个积极开发的领域,您可以在 这里 找到更多信息
在哪里可以获得量化模型的 ONNX 支持?
如果您在导出模型时遇到错误(使用
torch.onnx
下的 API),您可以在 PyTorch 存储库中打开一个 issue。在 issue 标题前加上[ONNX]
,并将 issue 标记为module: onnx
。如果您在使用 ONNX Runtime 时遇到问题,请在 GitHub - microsoft/onnxruntime 上打开一个 issue。
如何将量化与 LSTM 一起使用?
LSTM 通过我们的自定义模块 api 在 eager 模式和 fx 图模式量化中都受支持。示例可以在 Eager 模式中找到:pytorch/test_quantized_op.py TestQuantizedOps.test_custom_module_lstm FX 图模式:pytorch/test_quantize_fx.py TestQuantizeFx.test_static_lstm
常见错误¶
将非量化 Tensor 传递到量化内核中¶
如果您看到类似于以下的错误
RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend...
这意味着您正在尝试将非量化 Tensor 传递给量化内核。一个常见的解决方法是使用 torch.ao.quantization.QuantStub
来量化张量。这需要在 Eager 模式量化中手动完成。一个端到端示例
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.conv = torch.nn.Conv2d(1, 1, 1)
def forward(self, x):
# during the convert step, this will be replaced with a
# `quantize_per_tensor` call
x = self.quant(x)
x = self.conv(x)
return x
将量化 Tensor 传递到非量化内核中¶
如果您看到类似于以下的错误
RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend.
这意味着您正在尝试将量化 Tensor 传递给非量化内核。一个常见的解决方法是使用 torch.ao.quantization.DeQuantStub
来反量化张量。这需要在 Eager 模式量化中手动完成。一个端到端示例
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
# this module will not be quantized (see `qconfig = None` logic below)
self.conv2 = torch.nn.Conv2d(1, 1, 1)
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
# during the convert step, this will be replaced with a
# `quantize_per_tensor` call
x = self.quant(x)
x = self.conv1(x)
# during the convert step, this will be replaced with a
# `dequantize` call
x = self.dequant(x)
x = self.conv2(x)
return x
m = M()
m.qconfig = some_qconfig
# turn off quantization for conv2
m.conv2.qconfig = None
保存和加载量化模型¶
在量化模型上调用 torch.load
时,如果您看到类似以下的错误
AttributeError: 'LinearPackedParams' object has no attribute '_modules'
这是因为不支持使用 torch.save
和 torch.load
直接保存和加载量化模型。要保存/加载量化模型,可以使用以下方法
保存/加载量化模型的 state_dict
一个例子
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 5)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear(x)
x = self.relu(x)
return x
m = M().eval()
prepare_orig = prepare_fx(m, {'' : default_qconfig})
prepare_orig(torch.rand(5, 5))
quantized_orig = convert_fx(prepare_orig)
# Save/load using state_dict
b = io.BytesIO()
torch.save(quantized_orig.state_dict(), b)
m2 = M().eval()
prepared = prepare_fx(m2, {'' : default_qconfig})
quantized = convert_fx(prepared)
b.seek(0)
quantized.load_state_dict(torch.load(b))
使用
torch.jit.save
和torch.jit.load
保存/加载脚本化量化模型
一个例子
# Note: using the same model M from previous example
m = M().eval()
prepare_orig = prepare_fx(m, {'' : default_qconfig})
prepare_orig(torch.rand(5, 5))
quantized_orig = convert_fx(prepare_orig)
# save/load using scripted model
scripted = torch.jit.script(quantized_orig)
b = io.BytesIO()
torch.jit.save(scripted, b)
b.seek(0)
scripted_quantized = torch.jit.load(b)
使用 FX 图模式量化时出现符号追踪错误¶
符号可追溯性是 (原型 - 维护模式) FX 图模式量化 的一项要求,因此如果您将不具备符号可追溯性的 PyTorch 模型传递给 torch.ao.quantization.prepare_fx 或 torch.ao.quantization.prepare_qat_fx,我们可能会看到如下错误
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
请查看 符号追踪的局限性 并使用 - 关于使用 FX 图模式量化的用户指南 来解决问题。