(原型) FX 图模式量化用户指南¶
创建日期:2021 年 8 月 20 日 | 最后更新:2023 年 12 月 12 日 | 最后验证:2024 年 11 月 05 日
作者:Jerry Zhang
FX 图模式量化需要可符号跟踪的模型。我们使用 FX 框架将可符号跟踪的 nn.Module 实例转换为 IR,并在 IR 上操作以执行量化过程。请将您关于模型符号跟踪的问题发布到 PyTorch 讨论论坛
量化仅适用于模型中可符号跟踪的部分。使用符号跟踪值的数据依赖控制流(如 if 语句/for 循环等)是一种不支持的常见模式。如果您的模型无法端到端进行符号跟踪,您有几种选择可以在模型的一部分上启用 FX 图模式量化。您可以结合使用这些选项中的任何一种
- 不可跟踪的代码无需量化
仅符号跟踪需要量化的代码
跳过对不可跟踪代码的符号跟踪
- 不可跟踪的代码需要量化
重构您的代码使其可符号跟踪
编写您自己的已观察和量化子模块
如果不可符号跟踪的代码无需量化,我们有以下两种选择来运行 FX 图模式量化
仅符号跟踪需要量化的代码¶
当整个模型不可符号跟踪但我们想要量化的子模块可符号跟踪时,我们可以仅对该子模块进行量化。
之前
class M(nn.Module):
def forward(self, x):
x = non_traceable_code_1(x)
x = traceable_code(x)
x = non_traceable_code_2(x)
return x
之后
class FP32Traceable(nn.Module):
def forward(self, x):
x = traceable_code(x)
return x
class M(nn.Module):
def __init__(self):
self.traceable_submodule = FP32Traceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# We'll only symbolic trace/quantize this submodule
x = self.traceable_submodule(x)
x = self.traceable_code_2(x)
return x
量化代码
qconfig_mapping = QConfigMapping().set_global(qconfig)
model_fp32.traceable_submodule = \
prepare_fx(model_fp32.traceable_submodule, qconfig_mapping, example_inputs)
注意:如果需要保留原始模型,您必须在调用量化 API 之前自行复制它。
跳过对不可跟踪代码的符号跟踪¶
当模块中存在某些不可跟踪的代码,且这部分代码无需量化时,我们可以将这部分代码重构为一个子模块,并跳过对该子模块的符号跟踪。
之前
class M(nn.Module):
def forward(self, x):
x = self.traceable_code_1(x)
x = non_traceable_code(x)
x = self.traceable_code_2(x)
return x
之后,不可跟踪部分移至一个模块并标记为叶节点
class FP32NonTraceable(nn.Module):
def forward(self, x):
x = non_traceable_code(x)
return x
class M(nn.Module):
def __init__(self):
...
self.non_traceable_submodule = FP32NonTraceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# we will configure the quantization call to not trace through
# this submodule
x = self.non_traceable_submodule(x)
x = self.traceable_code_2(x)
return x
量化代码
qconfig_mapping = QConfigMapping.set_global(qconfig)
prepare_custom_config_dict = {
# option 1
"non_traceable_module_name": "non_traceable_submodule",
# option 2
"non_traceable_module_class": [MNonTraceable],
}
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict,
)
如果不可符号跟踪的代码需要量化,我们有以下两种选择
重构您的代码使其可符号跟踪¶
如果代码易于重构并使其可符号跟踪,我们可以重构代码并移除 Python 中不可跟踪的构造。
有关符号跟踪支持的更多信息可在此处找到:此处。
之前
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
这不可符号跟踪,因为在 x.view(*new_x_shape) 中,拆包 (unpacking) 不受支持,但是,由于 x.view 也支持列表输入,因此很容易移除拆包。
之后
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
这可以与其他方法结合使用,量化代码取决于模型。
编写您自己的已观察和量化子模块¶
如果不可跟踪的代码无法重构为可符号跟踪,例如它包含一些无法消除的循环,如 nn.LSTM,我们将需要将不可跟踪的代码重构为一个子模块(在 fx 图模式量化中,我们称之为 CustomModule),并定义该子模块的已观察和量化版本(在训练后静态量化或量化感知训练中,针对静态量化)或定义量化版本(在训练后动态和仅权重量化中)
之前
class M(nn.Module):
def forward(self, x):
x = traceable_code_1(x)
x = non_traceable_code(x)
x = traceable_code_1(x)
return x
之后
1. 将 non_traceable_code 重构到 FP32NonTraceable - 不可跟踪逻辑,封装在一个模块中
class FP32NonTraceable:
...
2. 定义 FP32NonTraceable 的已观察版本
class ObservedNonTraceable:
@classmethod
def from_float(cls, ...):
...
3. 定义 FP32NonTraceable 的静态量化版本,并定义一个类方法“from_observed”以从 ObservedNonTraceable 转换为 StaticQuantNonTraceable
class StaticQuantNonTraceable:
@classmethod
def from_observed(cls, ...):
...
# refactor parent class to call FP32NonTraceable
class M(nn.Module):
def __init__(self):
...
self.non_traceable_submodule = FP32NonTraceable(...)
def forward(self, x):
x = self.traceable_code_1(x)
# this part will be quantized manually
x = self.non_traceable_submodule(x)
x = self.traceable_code_1(x)
return x
量化代码
# post training static quantization or
# quantization aware training (that produces a statically quantized module)v
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
"static": {
FP32NonTraceable: ObservedNonTraceable,
}
},
}
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict)
校准 / 训练 (未显示)
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"static": {
ObservedNonTraceable: StaticQuantNonTraceable,
}
},
}
model_quantized = convert_fx(
model_prepared,
convert_custom_config_dict)
训练后动态/仅权重量化 在这两种模式下,我们不需要观察原始模型,因此我们只需要定义量化模型
class DynamicQuantNonTraceable: # or WeightOnlyQuantMNonTraceable
...
@classmethod
def from_observed(cls, ...):
...
prepare_custom_config_dict = {
"non_traceable_module_class": [
FP32NonTraceable
]
}
# The example is for post training quantization
model_fp32.eval()
model_prepared = prepare_fx(
model_fp32,
qconfig_mapping,
example_inputs,
prepare_custom_config_dict=prepare_custom_config_dict)
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"dynamic": {
FP32NonTraceable: DynamicQuantNonTraceable,
}
},
}
model_quantized = convert_fx(
model_prepared,
convert_custom_config_dict)
您还可以在 torch/test/quantization/test_quantize_fx.py
中的测试 test_custom_module_class
中找到自定义模块的示例。