快捷方式

量化 API 参考

torch.ao.quantization

此模块包含 Eager 模式的量化 API。

顶层 API

quantize

使用训练后静态量化对输入的浮点模型进行量化。

quantize_dynamic

将浮点模型转换为动态(即...

quantize_qat

执行量化感知训练并输出量化模型

prepare

准备模型的副本,用于量化校准或量化感知训练。

prepare_qat

准备模型的副本用于量化校准或量化感知训练,并将其转换为量化版本。

convert

根据 mapping 将输入模块中的子模块转换为不同的模块,通过调用目标模块类上的 from_float 方法实现。

准备模型进行量化

fuse_modules.fuse_modules

将模块列表融合为一个模块。

QuantStub

量化存根模块,在校准之前,它与观察者相同,将在 convert 中被替换为 nnq.Quantize

DeQuantStub

反量化存根模块,在校准之前,它与 identity 相同,将在 convert 中被替换为 nnq.DeQuantize

QuantWrapper

一个包装类,用于包装输入模块,添加 QuantStub 和 DeQuantStub,并在调用模块时围绕着调用量化和反量化模块。

add_quant_dequant

如果子模块具有有效的 qconfig,则将其包装在 QuantWrapper 中。请注意,此函数会原地修改模块的子模块,并且可以返回一个新的模块,该模块也包装了输入模块。

工具函数

swap_module

如果模块有量化对应物且附有 observer,则交换模块。

propagate_qconfig_

在模块层次结构中传播 qconfig,并在每个叶子模块上分配 qconfig 属性

default_eval_fn

定义默认评估函数。

torch.ao.quantization.quantize_fx

此模块包含 FX 图模式量化 API(原型)。

prepare_fx

准备模型进行训练后量化

prepare_qat_fx

准备模型进行量化感知训练

convert_fx

将校准或训练过的模型转换为量化模型

fuse_fx

融合诸如 conv+bn、conv+bn+relu 等模块,模型必须处于评估模式。

torch.ao.quantization.qconfig_mapping

此模块包含用于配置 FX 图模式量化的 QConfigMapping。

QConfigMapping

将模型操作映射到 torch.ao.quantization.QConfig

get_default_qconfig_mapping

返回训练后量化的默认 QConfigMapping。

get_default_qat_qconfig_mapping

返回量化感知训练的默认 QConfigMapping。

torch.ao.quantization.backend_config

此模块包含 BackendConfig,一个配置对象,用于定义后端如何支持量化。目前仅用于 FX 图模式量化,但我们可能会扩展 Eager 模式量化以支持它。

BackendConfig

定义给定后端上可量化的模式集合以及如何根据这些模式生成参考量化模型的配置。

BackendPatternConfig

指定给定运算符模式的量化行为的配置对象。

DTypeConfig

配置对象,指定参考模型规范中传递给量化操作作为参数的受支持数据类型,适用于输入和输出激活、权重和偏差。

DTypeWithConstraints

用于指定给定数据类型的附加约束的配置,例如量化值范围、比例值范围和固定量化参数,用于 DTypeConfig 中。

ObservationType

一个枚举,表示观察运算符/运算符模式的不同方式

torch.ao.quantization.fx.custom_config

此模块包含一些 CustomConfig 类,它们在 eager 模式和 FX 图模式量化中都使用

FuseCustomConfig

fuse_fx() 的自定义配置。

PrepareCustomConfig

prepare_fx()prepare_qat_fx() 的自定义配置。

ConvertCustomConfig

convert_fx() 的自定义配置。

StandaloneModuleConfigEntry

torch.ao.quantization.quantizer

torch.ao.quantization.pt2e (pytorch 2.0 导出实现的量化)

torch.ao.quantization.pt2e.export_utils

model_is_exported

如果 torch.nn.Module 被导出,则返回 True,否则返回 False(例如...

PT2 导出 (pt2e) 数值调试器

generate_numeric_debug_handle

为给定 ExportedProgram 的图模块中的所有节点(如 conv2d、squeeze、conv1d 等,占位符除外)附加 numeric_debug_handle_id。

CUSTOM_KEY

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

NUMERIC_DEBUG_HANDLE_KEY

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

prepare_for_propagation_comparison

为具有 numeric_debug_handle 的节点添加输出记录器

extract_results_from_loggers

对于给定的模型,提取每个调试句柄的张量统计信息和相关信息。

compare_results

给定两个将 debug_handle_id (int) 映射到张量列表的字典,返回一个将 debug_handle_id 映射到 NodeAccuracySummary 的字典,其中包含 SQNR、MSE 等比较信息。

torch.ao.quantization.observer

此模块包含观察者,用于收集校准 (PTQ) 或训练 (QAT) 期间观察到的值的统计信息。

ObserverBase

基础观察者模块。

MinMaxObserver

基于运行中的最小值和最大值计算量化参数的观察者模块。

MovingAverageMinMaxObserver

基于最小值和最大值移动平均值计算量化参数的观察者模块。

PerChannelMinMaxObserver

基于运行中的逐通道最小值和最大值计算量化参数的观察者模块。

MovingAveragePerChannelMinMaxObserver

基于运行中的逐通道最小值和最大值计算量化参数的观察者模块。

HistogramObserver

该模块记录张量值的运行直方图以及最小值/最大值。

PlaceholderObserver

一个不做任何事情的观察者,仅将其配置传递给量化模块的 .from_float() 方法。

RecordingObserver

该模块主要用于调试,并在运行时记录张量值。

NoopObserver

一个不做任何事情的观察者,仅将其配置传递给量化模块的 .from_float() 方法。

get_observer_state_dict

返回对应于观察者统计信息的 state dict。

load_observer_state_dict

给定输入模型和一个包含模型观察者统计信息的 state_dict,将统计信息加载回模型中。

default_observer

静态量化的默认观察者,通常用于调试。

default_placeholder_observer

默认占位符观察者,通常用于量化到 torch.float16。

default_debug_observer

默认仅用于调试的观察者。

default_weight_observer

默认权重观察者。

default_histogram_observer

默认直方图观察者,通常用于 PTQ。

default_per_channel_weight_observer

默认逐通道权重观察者,通常用于支持逐通道权重量化的后端,例如 fbgemm

default_dynamic_quant_observer

动态量化的默认观察者。

default_float_qparams_observer

浮点 zero-point 的默认观察者。

AffineQuantizedObserverBase

仿射量化观察者模块 (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

Granularity

表示量化粒度的基类。

MappingType

浮点数如何映射到整数

PerAxis

表示量化中的逐轴粒度。

PerBlock

表示量化中的逐块粒度。

PerGroup

表示量化中的逐通道组粒度。

PerRow

表示量化中的逐行粒度。

PerTensor

表示量化中的逐张量粒度。

PerToken

表示量化中的逐 token 粒度。

TorchAODType

尚不存在于 PyTorch 核心中的数据类型的占位符。

ZeroPointDomain

一个枚举,指示 zero_point 是在整数域还是浮点域中

get_block_size

根据输入形状和粒度类型获取块大小。

torch.ao.quantization.fake_quantize

此模块实现了在 QAT 期间用于执行伪量化的模块。

FakeQuantizeBase

基础伪量化模块。

FakeQuantize

在训练期间模拟量化和反量化操作。

FixedQParamsFakeQuantize

在训练期间模拟量化和反量化。

FusedMovingAvgObsFakeQuantize

定义一个融合模块来观察张量。

default_fake_quant

激活的默认 fake_quant。

default_weight_fake_quant

权重的默认 fake_quant。

default_per_channel_weight_fake_quant

逐通道权重的默认 fake_quant。

default_histogram_fake_quant

使用直方图对激活进行伪量化。

default_fused_act_fake_quant

default_fake_quant 的融合版本,性能有所提升。

default_fused_wt_fake_quant

default_weight_fake_quant 的融合版本,性能有所提升。

default_fused_per_channel_wt_fake_quant

default_per_channel_weight_fake_quant 的融合版本,性能有所提升。

disable_fake_quant

禁用模块的伪量化。

enable_fake_quant

启用模块的伪量化。

disable_observer

禁用此模块的观察。

enable_observer

启用此模块的观察。

torch.ao.quantization.qconfig

此模块定义 QConfig 对象,用于配置单个操作的量化设置。

QConfig

描述如何通过分别提供激活和权重的设置(观察者类)来量化网络层或网络的一部分。

default_qconfig

默认 qconfig 配置。

default_debug_qconfig

用于调试的默认 qconfig 配置。

default_per_channel_qconfig

逐通道权重量化的默认 qconfig 配置。

default_dynamic_qconfig

默认动态 qconfig。

float16_dynamic_qconfig

权重被量化到 torch.float16 的动态 qconfig。

float16_static_qconfig

激活和权重都量化到 torch.float16 的动态 qconfig。

per_channel_dynamic_qconfig

权重逐通道量化的动态 qconfig。

float_qparams_weight_only_qconfig

权重使用浮点 zero_point 量化的动态 qconfig。

default_qat_qconfig

QAT 的默认 qconfig。

default_weight_only_qconfig

仅量化权重的默认 qconfig。

default_activation_only_qconfig

仅量化激活的默认 qconfig。

default_qat_qconfig_v2

default_qat_config 的融合版本,具有性能优势。

torch.ao.nn.intrinsic

此模块实现了可用于量化的组合(融合)模块 conv + relu。

ConvReLU1d

这是一个顺序容器,调用 Conv1d 和 ReLU 模块。

ConvReLU2d

这是一个顺序容器,调用 Conv2d 和 ReLU 模块。

ConvReLU3d

这是一个顺序容器,调用 Conv3d 和 ReLU 模块。

LinearReLU

这是一个顺序容器,调用 Linear 和 ReLU 模块。

ConvBn1d

这是一个顺序容器,调用 Conv 1d 和 Batch Norm 1d 模块。

ConvBn2d

这是一个顺序容器,调用 Conv 2d 和 Batch Norm 2d 模块。

ConvBn3d

这是一个顺序容器,调用 Conv 3d 和 Batch Norm 3d 模块。

ConvBnReLU1d

这是一个顺序容器,调用 Conv 1d、Batch Norm 1d 和 ReLU 模块。

ConvBnReLU2d

这是一个顺序容器,调用 Conv 2d、Batch Norm 2d 和 ReLU 模块。

ConvBnReLU3d

这是一个顺序容器,调用 Conv 3d、Batch Norm 3d 和 ReLU 模块。

BNReLU2d

这是一个顺序容器,调用 BatchNorm 2d 和 ReLU 模块。

BNReLU3d

这是一个顺序容器,调用 BatchNorm 3d 和 ReLU 模块。

torch.ao.nn.intrinsic.qat

此模块实现了量化感知训练所需的这些融合操作的版本。

LinearReLU

一个由 Linear 和 ReLU 模块融合而成的 LinearReLU 模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvBn1d

ConvBn1d 模块是由 Conv1d 和 BatchNorm1d 融合而成的模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvBnReLU1d

ConvBnReLU1d 模块是由 Conv1d、BatchNorm1d 和 ReLU 融合而成的模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvBn2d

ConvBn2d 模块是由 Conv2d 和 BatchNorm2d 融合而成的模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvBnReLU2d

ConvBnReLU2d 模块是由 Conv2d、BatchNorm2d 和 ReLU 融合而成的模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvReLU2d

ConvReLU2d 模块是 Conv2d 和 ReLU 的融合模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvBn3d

ConvBn3d 模块是由 Conv3d 和 BatchNorm3d 融合而成的模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvBnReLU3d

ConvBnReLU3d 模块是由 Conv3d、BatchNorm3d 和 ReLU 融合而成的模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

ConvReLU3d

ConvReLU3d 模块是 Conv3d 和 ReLU 的融合模块,附加了用于权重的 FakeQuantize 模块,用于量化感知训练。

update_bn_stats

freeze_bn_stats

torch.ao.nn.intrinsic.quantized

此模块实现了卷积+relu等融合操作的量化实现。不包含BatchNorm变体,因为它们通常在推理时折叠到卷积中。

BNReLU2d

BNReLU2d模块是BatchNorm2d和ReLU的融合模块

BNReLU3d

BNReLU3d模块是BatchNorm3d和ReLU的融合模块

ConvReLU1d

ConvReLU1d模块是Conv1d和ReLU的融合模块

ConvReLU2d

ConvReLU2d模块是Conv2d和ReLU的融合模块

ConvReLU3d

ConvReLU3d模块是Conv3d和ReLU的融合模块

LinearReLU

LinearReLU模块由Linear和ReLU模块融合而成

torch.ao.nn.intrinsic.quantized.dynamic

此模块实现了线性+relu等融合操作的量化动态实现。

LinearReLU

一个由Linear和ReLU模块融合而成的LinearReLU模块,可用于动态量化。

torch.ao.nn.qat

此模块实现了关键nn模块(如Conv2d()Linear())的版本,它们在FP32中运行,但应用了舍入以模拟INT8量化的效果。

Conv2d

一个附带用于权重的FakeQuantize模块的Conv2d模块,用于量化感知训练。

Conv3d

一个附带用于权重的FakeQuantize模块的Conv3d模块,用于量化感知训练。

Linear

一个附带用于权重的FakeQuantize模块的Linear模块,用于量化感知训练。

torch.ao.nn.qat.dynamic

此模块实现了关键nn模块(如Linear())的版本,它们在FP32中运行,但应用了舍入以模拟INT8量化的效果,并在推理期间动态量化。

Linear

一个附带用于权重的FakeQuantize模块的Linear模块,用于动态量化感知训练。

torch.ao.nn.quantized

此模块实现了nn层(如~`torch.nn.Conv2d`和torch.nn.ReLU)的量化版本。

ReLU6

应用逐元素函数

Hardswish

这是Hardswish的量化版本。

ELU

这是ELU的量化等效项。

LeakyReLU

这是LeakyReLU的量化等效项。

Sigmoid

这是Sigmoid的量化等效项。

BatchNorm2d

这是BatchNorm2d的量化版本。

BatchNorm3d

这是BatchNorm3d的量化版本。

Conv1d

对由多个量化输入平面组成的量化输入信号应用一维卷积。

Conv2d

对由多个量化输入平面组成的量化输入信号应用二维卷积。

Conv3d

对由多个量化输入平面组成的量化输入信号应用三维卷积。

ConvTranspose1d

对由多个输入平面组成的输入图像应用一维转置卷积运算。

ConvTranspose2d

对由多个输入平面组成的输入图像应用二维转置卷积运算。

ConvTranspose3d

对由多个输入平面组成的输入图像应用三维转置卷积运算。

Embedding

一种量化Embedding模块,使用量化打包权重作为输入。

EmbeddingBag

一种量化EmbeddingBag模块,使用量化打包权重作为输入。

FloatFunctional

用于浮点运算的状态收集器类。

FXFloatFunctional

在FX图模式量化之前用于替换FloatFunctional模块的模块,因为activation_post_process将直接插入到顶层模块中。

QFunctional

量化运算的包装器类。

Linear

一种量化Linear模块,使用量化张量作为输入和输出。

LayerNorm

这是LayerNorm的量化版本。

GroupNorm

这是GroupNorm的量化版本。

InstanceNorm1d

这是InstanceNorm1d的量化版本。

InstanceNorm2d

这是InstanceNorm2d的量化版本。

InstanceNorm3d

这是InstanceNorm3d的量化版本。

torch.ao.nn.quantized.functional

函数式接口(量化)。

此模块实现了函数层(如~`torch.nn.functional.conv2d`和torch.nn.functional.relu)的量化版本。注意:relu()支持量化输入。

avg_pool2d

kH×kWkH \times kW区域应用二维平均池化运算,步长为sH×sWsH \times sW

avg_pool3d

kD timeskH×kWkD \ times kH \times kW区域应用三维平均池化运算,步长为sD×sH×sWsD \times sH \times sW

adaptive_avg_pool2d

对由多个量化输入平面组成的量化输入信号应用二维自适应平均池化。

adaptive_avg_pool3d

对由多个量化输入平面组成的量化输入信号应用三维自适应平均池化。

conv1d

对由多个输入平面组成的量化一维输入应用一维卷积。

conv2d

对由多个输入平面组成的量化二维输入应用二维卷积。

conv3d

对由多个输入平面组成的量化三维输入应用三维卷积。

interpolate

将输入下采样/上采样到给定的size或给定的scale_factor

linear

对输入的量化数据应用线性变换:y=xAT+by = xA^T + b

max_pool1d

对由多个量化输入平面组成的量化输入信号应用一维最大池化。

max_pool2d

对由多个量化输入平面组成的量化输入信号应用二维最大池化。

celu

逐元素应用量化CELU函数。

leaky_relu

的量化版本。

hardtanh

这是hardtanh()的量化版本。

hardswish

这是hardswish()的量化版本。

threshold

逐元素应用阈值函数的量化版本

elu

这是elu()的量化版本。

hardsigmoid

这是hardsigmoid()的量化版本。

clamp

float(input, min_, max_) -> Tensor

upsample

将输入上采样到给定的size或给定的scale_factor

upsample_bilinear

使用双线性上采样对输入进行上采样。

upsample_nearest

使用最近邻像素值对输入进行上采样。

torch.ao.nn.quantizable

此模块实现了一些nn层的可量化版本。这些模块可以与自定义模块机制结合使用,通过在准备和转换时提供custom_module_config参数来实现。

LSTM

一种可量化的长短期记忆(LSTM)模块。

MultiheadAttention

torch.ao.nn.quantized.dynamic

动态量化的LinearLSTMLSTMCellGRUCellRNNCell

Linear

一种动态量化的Linear模块,使用浮点张量作为输入和输出。

LSTM

一种动态量化的LSTM模块,使用浮点张量作为输入和输出。

GRU

将多层门控循环单元(GRU)RNN应用于输入序列。

RNNCell

具有tanh或ReLU非线性的Elman RNN单元。

LSTMCell

一种长短期记忆(LSTM)单元。

GRUCell

一种门控循环单元(GRU)单元

量化数据类型和量化方案

请注意,运算符实现目前仅支持对convlinear运算符权重进行逐通道量化。此外,输入数据按如下方式线性映射到量化数据,反之亦然:

Quantization:Qout=clamp(xinput/s+z,Qmin,Qmax)Dequantization:xout=(Qinputz)s\begin{aligned} \text{Quantization:}&\\ &Q_\text{out} = \text{clamp}(x_\text{input}/s+z, Q_\text{min}, Q_\text{max})\\ \text{Dequantization:}&\\ &x_\text{out} = (Q_\text{input}-z)*s \end{aligned}

其中clamp(.)\text{clamp}(.)clamp()相同,而缩放因子ss和零点zz则按MinMaxObserver中的描述进行计算,具体如下:

if Symmetric:s=2max(xmin,xmax)/(QmaxQmin)z={0if dtype is qint8128otherwiseOtherwise:s=(xmaxxmin)/(QmaxQmin)z=Qminround(xmin/s)\begin{aligned} \text{if Symmetric:}&\\ &s = 2 \max(|x_\text{min}|, x_\text{max}) / \left( Q_\text{max} - Q_\text{min} \right) \\ &z = \begin{cases} 0 & \text{if dtype is qint8} \\ 128 & \text{otherwise} \end{cases}\\ \text{Otherwise:}&\\ &s = \left( x_\text{max} - x_\text{min} \right ) / \left( Q_\text{max} - Q_\text{min} \right ) \\ &z = Q_\text{min} - \text{round}(x_\text{min} / s) \end{aligned}

其中 [xmin,xmax][x_\text{min}, x_\text{max}] 表示输入数据的范围,而 QminQ_\text{min}QmaxQ_\text{max} 分别表示量化数据类型的最小值和最大值。

请注意,sszz 的选择意味着只要零位于输入数据的范围内或正在使用对称量化时,零就可以表示而没有量化误差。

可以通过 自定义算子机制 来实现其他数据类型和量化方案。

  • torch.qscheme — 用于描述张量量化方案的类型。支持的类型

    • torch.per_tensor_affine — 每张量,非对称

    • torch.per_channel_affine — 每通道,非对称

    • torch.per_tensor_symmetric — 每张量,对称

    • torch.per_channel_symmetric — 每通道,对称

  • torch.dtype — 用于描述数据的类型。支持的类型

    • torch.quint8 — 8位无符号整数

    • torch.qint8 — 8位有符号整数

    • torch.qint32 — 32位有符号整数

QAT 模块。

此包正在被弃用。请改用 torch.ao.nn.qat.modules

QAT 动态模块。

此包正在被弃用。请改用 torch.ao.nn.qat.dynamic

此文件正在迁移到 torch/ao/quantization,并暂时保留在此处以保持兼容性。如果您要添加新的条目/功能,请将其添加到 torch/ao/quantization/fx/ 下的相应文件中,并在此处添加导入语句。

QAT 动态模块。

此包正在被弃用。请改用 torch.ao.nn.qat.dynamic

量化模块。

注意:

torch.nn.quantized 命名空间正在被弃用。请改用 torch.ao.nn.quantized

量化动态模块。

此文件正在迁移到 torch/ao/nn/quantized/dynamic,并暂时保留在此处以保持兼容性。如果您要添加新的条目/功能,请将其添加到 torch/ao/nn/quantized/dynamic 下的相应文件中,并在此处添加导入语句。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答您的问题

查看资源