如何为 PyTorch 2 导出量化编写 Quantizer
¶
创建于:2023 年 7 月 28 日 | 最后更新:2024 年 8 月 1 日 | 最后验证:2024 年 11 月 05 日
作者: Leslie Fang, Weiwen Xia, Jiong Gong, Kimish Patel, Jerry Zhang
简介¶
(原型)PyTorch 2 导出后训练量化 介绍了 pytorch 2 导出量化的总体 API,API 方面与 fx 图模式量化的主要区别在于,我们明确指出量化是针对特定后端的。因此,要使用新的流程,后端需要实现一个 Quantizer
类,该类编码:(1)。后端中支持的量化算子或模式 (2)。用户如何表达他们希望如何量化浮点模型的方式,例如,将整个模型量化为 int8 对称量化,或仅量化线性层等。
请参阅 此处 了解新 API 和 Quantizer
的动机。
为 XNNPACK
定义的现有量化器对象位于 QNNPackQuantizer 中
注解 API¶
Quantizer
使用注解 API 来传达不同算子/模式的量化意图。注解 API 主要由 QuantizationSpec 和 QuantizationAnnotation 组成。
QuantizationSpec
用于传达张量将如何量化的意图,例如 dtype、位宽、最小值、最大值、对称与非对称等。此外,QuantizationSpec
还允许量化器指定应如何观察张量值,例如 MinMaxObserver
或 HistogramObserver
或某些自定义观察者。
QuantizationAnnotation
由 QuantizationSpec
对象组成,用于注解模式的输入张量和输出张量。注解输入张量等效于注解输入边,而注解输出张量等效于注解节点。QuantizationAnnotation
是一个带有多个字段的 dataclass
input_qspec_map
字段是Dict
类,用于将每个输入张量(作为输入边)映射到QuantizationSpec
。output_qspec
字段表示用于注解输出张量的QuantizationSpec
;_annotated
字段指示此节点是否已被量化器注解。
总之,注解 API 要求量化器注解图的边(输入张量)或节点(输出张量)。现在,我们将提供一个逐步教程,介绍如何将注解 API 与不同类型的 QuantizationSpec
一起使用。
1. 注解常用算子模式¶
为了使用量化模式/算子,例如 quantized add
,后端开发人员将有意量化(如 QuantizationSpec
所表达的)模式的输入、输出。以下是一个示例流程(以 add
算子为例),说明此意图如何在带有注解 API 的量化工作流程中传达。
步骤 1:识别 FX 图中的原始浮点模式。有几种方法可以识别此模式:量化器可以使用模式匹配器来匹配算子模式;量化器可以从头到尾遍历节点,并将节点的target类型与算子模式进行匹配。在本例中,我们可以使用 get_source_partitions 来匹配此模式。原始浮点
add
模式仅包含一个add
节点。
add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
add_partitions = list(itertools.chain(*add_partitions.values()))
for add_partition in add_partitions:
add_node = add_partition.output_nodes[0]
步骤 2:为模式的输入和输出定义
QuantizationSpec
。QuantizationSpec
定义了数据 类型
、qscheme
和其他量化参数,这些参数是关于用户如何观察或伪量化张量的意图。
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
input_act_qspec = act_quantization_spec
output_act_qspec = act_quantization_spec
步骤 3:使用
QuantizationAnnotation
注解模式的输入和输出。在本例中,我们将使用在上述步骤 2 中创建的QuantizationSpec
对象为add
节点的两个输入和一个输出创建QuantizationAnnotation
对象。
input_qspec_map = {}
input_act0 = add_node.args[0]
input_qspec_map[input_act0] = input_act_qspec
input_act1 = add_node.args[1]
input_qspec_map[input_act1] = input_act_qspec
add_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_act_qspec,
_annotated=True,
)
在我们像这样注解 add
节点之后,在后续的量化流程中,将在准备阶段在其两个输入节点和一个输出节点处插入 HistogramObserver
。并且在转换阶段,HistogramObserver
将被替换为 quantize
节点和 dequantize
节点。
3. 注解具有固定量化参数的算子¶
注解量化模型的另一个典型用例是针对其量化参数预先已知的张量。例如,像 sigmoid
这样的算子,其输入和输出张量具有预定义和固定的 scale/zero_point。FixedQParamsQuantizationSpec 专为此用例而设计。要使用 FixedQParamsQuantizationSpec
,用户需要显式传入 scale
和 zero_point
的参数。
步骤 1:识别 FX 图中的原始浮点模式。我们可以使用
QuantizationSpec
示例中介绍的相同方法来识别sigmoid
模式。步骤 2:创建
FixedQParamsQuantizationSpec
对象,其输入为固定的scale
、zero_point
值。这些值将用于在转换阶段创建quantize
节点和dequantize
节点。步骤 3:注解输入和输出以使用此
FixedQParamsQuantizationSpec
对象。
act_qspec = FixedQParamsQuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
scale=1.0 / 256.0,
zero_point=0,
)
sigmoid_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={input_act: act_qspec},
output_qspec=act_qspec,
_annotated=True,
)
4. 注解具有派生量化参数的张量¶
另一个用例是定义张量的约束,这些张量的量化参数是从其他张量派生的。例如,如果我们想注解一个卷积节点,并将它的偏置输入张量的 scale
定义为激活张量的 scale
和权重张量的 scale
的乘积。我们可以使用 DerivedQuantizationSpec 来注解此 conv 节点。
步骤 1:识别 FX 图中的原始浮点模式。我们可以使用
QuantizationSpec
示例中介绍的相同方法来识别convolution
模式。步骤 2:定义
derive_qparams_fn
函数,它接受ObserverOrFakeQuantize
列表(ObserverBase 或 FakeQuantizeBase)作为输入。从每个ObserverOrFakeQuantize
对象,用户可以获取scale
、zero point
值。用户可以定义其关于如何基于从观察者或伪量化实例计算的量化参数派生新的scale
、zero point
值的启发式方法。步骤 3:定义
DerivedQuantizationSpec
对象,它接受以下输入:EdgeOrNode
对象列表。与每个EdgeOrNode
对象对应的观察者将传递到derive_qparams_fn
函数;derive_qparams_fn
函数;几个其他量化参数,例如dtype
、qscheme
。步骤 4:使用
QuantizationAnnotation
注解此 conv 节点的输入和输出。
def derive_qparams_fn(obs_or_fqs: List[ObserverOrFakeQuantize]) -> Tuple[Tensor, Tensor]:
assert len(obs_or_fqs) == 2, \
"Expecting two obs/fqs, one for activation and one for weight, got: {}".format(len(obs_or_fq))
act_obs_or_fq = obs_or_fqs[0]
weight_obs_or_fq = obs_or_fqs[1]
act_scale, act_zp = act_obs_or_fq.calculate_qparams()
weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
return torch.tensor([act_scale * weight_scale]).to(torch.float32), torch.tensor([0]).to(torch.int32)
bias_qspec = DerivedQuantizationSpec(
derived_from=[(input_act, node), (weight, node)],
derive_qparams_fn=derive_qparams_fn,
dtype=torch.int32,
quant_min=-2**31,
quant_max=2**31 - 1,
qscheme=torch.per_tensor_symmetric,
)
input_qspec_map = {input_act: act_quantization_spec, weight: weight_quantization_spec, bias: bias_qspec}
node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=act_quantization_spec,
_annotated=True,
)
5. 使用 Resnet18 的玩具示例¶
在使用 QuantizationAnnotation API
定义了上述注解方法之后,我们现在可以将它们放在一起以构建一个 BackendQuantizer
,并使用 Torchvision Resnet18
运行一个 玩具示例。为了更好地理解最终示例,以下是示例中使用的类和实用函数
QuantizationConfig 由分别用于激活、权重和偏置的
QuantizationSpec
组成。在注解模型时,可以使用 get_input_act_qspec、get_output_act_qspec、get_weight_qspec 和 get_bias_qspec 从
QuantizationConfig
中获取特定模式的QuantizationSpec
。
关于 PT2E 量化流程的 IR 的说明¶
IR 指的是模型的中间表示,例如,torch
IR (torch.nn
模块, torch.nn.functional
操作) 或 aten
IR (torch.ops.aten.linear
, …)。 PT2E 量化流程使用预 autograd aten IR ( torch.export API 的输出),以便我们支持训练。 如前所示,我们需要匹配运算符或运算符模式,然后才能在其上附加注释。 因此,问题是我们如何匹配模式?
动机:直接匹配 aten
IR 的问题¶
最直接的方法可能是直接匹配 aten
IR。
示例
for n in gm.graph.nodes:
if n.op != "call_function" or n.target not in [
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
]:
continue
relu_node = n
maybe_conv_node = n.args[0]
if (
not isinstance(maybe_conv_node, Node)
or maybe_conv_node.op != "call_function"
or maybe_conv_node.target
not in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]
):
continue
# annotate conv and relu nodes
...
然而,使用这种 IR 的一个问题是,如果模块或函数操作的 PyTorch 实现发生更改,则表示可能会更改。 但这可能是意想不到的,因为建模用户通常假设,当 eager 模式模型代码没有更改时,他们也应该在程序捕获后获得相同的模型表示。 此问题的一个具体影响是,如果 Quantizer
基于识别 aten
IR 模式进行注释,那么在 PyTorch 版本更新后,它可能无法识别该模式,并且相同的 eager 模式浮点模型可能会保持未量化状态。
建议:使用 SubgraphMatcherWithNameNodeMap
进行模式匹配¶
因此,我们建议人们通过捕获 torch
IR 模式(使用与捕获浮点模型相同的程序捕获),而不是直接使用 aten
IR 模式,通过 SubgraphMatcherWithNameNodeMap
(SubgraphMatcher
的改进版本,使其更容易查询人们想要注释的节点)来识别模式。
示例
def conv_relu_pattern(input, weight, bias):
conv = torch.nn.functional.conv2d(input, weight, bias)
output = torch.nn.functional.relu(conv)
# returns an additional dict that includes a map from name to node that we want to annotate
return relu, {"input": input, "weight": weight, "bias": bias, "output": output}
matcher = SubgraphMatcherWithNameNodeMap(conv_relu_pattern)
matches = matcher.match(model)
for match in matches:
# find input and output of the pattern
# annotate the nodes
name_node_map = match.name_node_map
input_node = name_node_map["input"]
weight_node = name_node_map["weight"]
bias_node = name_node_map["bias"]
output_node = name_node_map["relu"]
input_node.users[0].meta["quantization_annotation"] = ...
weight_node.users[0].meta["quantization_annotation"] = ...
bias_node.users[0].meta["quantization_annotation"] = ...
output_node.meta["quantization_annotation"] = ...
通过这种方式,即使 nn 模块和函数的实现发生更改,Quantizer
仍然有效,浮点模型的 aten
IR 将会更改,但由于我们再次捕获模式而不是硬编码模式的 aten
IR,我们将获得更新后的 aten
IR,并且仍然能够匹配模式。
一个需要注意的地方是,如果模式的输入有多个用户,我们没有很好的方法来识别我们想要注释哪个用户节点,除非检查 aten 操作目标。
另一个需要注意的地方是,我们需要确保我们有一个详尽的示例列表(例如,2D、3D、4D 输入,实数与符号输入,training=True 与 training=False 等),以便模式涵盖从 torch
IR 模式捕获的不同可能的 aten
IR 结果。
注意:我们可能会提供一些(模式,example_inputs 列表)或一些预生成的 matcher 对象,以便人们将来可以直接使用它们。
结论¶
通过本教程,我们介绍了 PyTorch 2 中的新量化路径。 用户可以学习如何使用 QuantizationAnnotation API
定义 BackendQuantizer
,并将其集成到 PyTorch 2 Export 量化流程中。 给出了 QuantizationSpec
、SharedQuantizationSpec
、FixedQParamsQuantizationSpec
和 DerivedQuantizationSpec
的示例,用于特定的注释用例。 您可以使用 XNNPACKQuantizer 作为示例,开始实现您自己的 Quantizer
。 之后,请按照 本教程 实际量化您的模型。