torch.ao.ns._numeric_suite¶
警告
此模块是早期原型,可能会发生更改。
- torch.ao.ns._numeric_suite.compare_weights(float_dict, quantized_dict)[源代码][源代码]¶
比较浮点模块与其对应的量化模块的权重。返回一个字典,键对应模块名称,每个条目都是一个字典,包含两个键 ‘float’ 和 ‘quantized’,分别包含浮点权重和量化权重。此字典可用于比较和计算浮点模型和量化模型的权重量化误差。
使用示例
wt_compare_dict = compare_weights( float_model.state_dict(), qmodel.state_dict()) for key in wt_compare_dict: print( key, compute_error( wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize() ) )
- 参数
float_dict (Dict[str, Any]) – 浮点模型的状态字典
quantized_dict (Dict[str, Any]) – 量化模型的状态字典
- 返回
字典,键对应模块名称,每个条目都是一个字典,包含两个键 ‘float’ 和 ‘quantized’,分别包含浮点权重和量化权重
- 返回类型
weight_dict
- torch.ao.ns._numeric_suite.get_logger_dict(mod, prefix='')[源代码][源代码]¶
遍历模块并将所有 logger 统计信息保存到目标字典中。这主要用于量化精度调试。
- 支持的 logger 类型
ShadowLogger: 用于记录量化模块及其匹配的浮点影子模块的输出, OutputLogger: 用于记录模块的输出
- 参数
mod (Module) – 我们要保存所有 logger 统计信息的模块
prefix (str) – 当前模块的前缀
- 返回
用于保存所有 logger 统计信息的字典
- 返回类型
target_dict
- class torch.ao.ns._numeric_suite.Shadow(q_module, float_module, logger_cls)[源代码][源代码]¶
Shadow 模块将浮点模块作为影子附加到其匹配的量化模块。然后,它使用 Logger 模块来处理两个模块的输出。
- 参数
q_module – 从 float_module 量化而来的模块,我们要对其进行影子处理
float_module – 用于影子 q_module 的浮点模块
logger_cls – 用于处理 q_module 和 float_module 输出的 logger 类型。可以使用 ShadowLogger 或自定义 logger。
- torch.ao.ns._numeric_suite.prepare_model_with_stubs(float_module, q_module, module_swap_list, logger_cls)[源代码][源代码]¶
通过将浮点模块作为影子附加到其匹配的量化模块来准备模型,如果浮点模块类型在 module_swap_list 中。
使用示例
prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) q_model(data) ob_dict = get_logger_dict(q_model)
- torch.ao.ns._numeric_suite.compare_model_stub(float_model, q_model, module_swap_list, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.ShadowLogger'>)[源代码][源代码]¶
将模型中的量化模块与其浮点对应模块进行比较,并向它们提供相同的输入。返回一个字典,键对应模块名称,每个条目都是一个字典,包含两个键 ‘float’ 和 ‘quantized’,分别包含量化模块及其匹配的浮点影子模块的输出张量。此字典可用于比较和计算模块级别的量化误差。
此函数首先调用 prepare_model_with_stubs() 来交换我们要比较的量化模块和 Shadow 模块,Shadow 模块接受量化模块、相应的浮点模块和 logger 作为输入,并在内部创建转发路径,使浮点模块能够影子化共享相同输入的量化模块。logger 可以自定义,默认 logger 是 ShadowLogger,它将保存量化模块和浮点模块的输出,这些输出可用于计算模块级别的量化误差。
使用示例
module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock] ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data) for key in ob_dict: print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
- torch.ao.ns._numeric_suite.get_matching_activations(float_module, q_module)[源代码][源代码]¶
查找浮点模块和量化模块之间匹配的激活。
- torch.ao.ns._numeric_suite.prepare_model_outputs(float_module, q_module, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[源代码][源代码]¶
通过将 logger 附加到浮点模块和量化模块(如果它们在 allow_list 中)来准备模型。
- torch.ao.ns._numeric_suite.compare_model_outputs(float_model, q_model, *data, logger_cls=<class 'torch.ao.ns._numeric_suite.OutputLogger'>, allow_list=None)[源代码][源代码]¶
比较浮点模型和量化模型在相同输入下对应位置的输出激活。返回一个字典,键对应量化模块名称,每个条目都是一个字典,包含两个键 ‘float’ 和 ‘quantized’,分别包含量化模型和浮点模型在匹配位置的激活。此字典可用于比较和计算传播量化误差。
使用示例
act_compare_dict = compare_model_outputs(float_model, qmodel, data) for key in act_compare_dict: print( key, compute_error( act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize() ) )
- 参数
- 返回
字典,键对应量化模块名称,每个条目都是一个字典,包含两个键 ‘float’ 和 ‘quantized’,分别包含匹配的浮点激活和量化激活
- 返回类型
act_compare_dict