Quantization
- torchao.quantization.quantize_(model: Module, apply_tensor_subclass: Callable[[Module], Module], filter_fn: Optional[Callable[[Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[Union[device, str, int]] = None)
Converts the weights of the linear modules in the model with apply_tensor_subclass. The model is modified in place.
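- Parameters:
model (torch.nn.Module) – the input model
apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]) – a function that applies a tensor subclass conversion to the weight of a module and returns the module (e.g. converting the weight tensor of a linear module to an affine quantized tensor)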
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]) – a function that takes an nn.Module instance and the fully qualified name of the module, and returns True if apply_tensor_subclass should be run on the weight of that module
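set_inductor_config (bool, optional) – whether to automatically apply the recommended inductor config settings (defaults to True)
device (device, optional) – the device to move the module to before applying filter_fn. It can be set to "cuda" to speed up quantization. The final model will be on the specified device. Defaults to None (the device is not changed).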
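To make the roles of filter_fn and apply_tensor_subclass concrete, here is a minimal sketch in plain PyTorch of the kind of in-place module-tree walk these parameters describe. It is not torchao's implementation: toy_quantize_, fake_int8_weight_only and skip_last_linear are hypothetical names, the default filter (every nn.Linear) is an assumption, and the transform stores a plain quantize-then-dequantize float Parameter instead of installing a tensor subclass.

```python
import torch
import torch.nn as nn
from typing import Callable, List, Optional


def toy_quantize_(
    model: nn.Module,
    apply_fn: Callable[[nn.Module], nn.Module],
    filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> None:
    """Hypothetical stand-in for quantize_: rewrites matching submodules in place."""
    if filter_fn is None:
        # Assumed default for this sketch: transform every nn.Linear
        filter_fn = lambda mod, fqn: isinstance(mod, nn.Linear)

    def _walk(module: nn.Module, prefix: str) -> None:
        for name, child in list(module.named_children()):
            # Fully qualified name, e.g. "0" or "block0.proj"
            fqn = f"{prefix}.{name}" if prefix else name
            if filter_fn(child, fqn):
                # Swap in whatever the transform returns (here the same module, mutated)
                setattr(module, name, apply_fn(child))
            else:
                # Not selected: keep descending so nested modules are still visited
                _walk(child, fqn)

    _walk(model, "")


def fake_int8_weight_only(linear: nn.Linear) -> nn.Linear:
    """Illustrative transform: symmetric per-row int8 quantize-then-dequantize."""
    w = linear.weight.detach()
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127)
    # A plain Parameter replaces the weight; torchao would install a tensor subclass instead
    linear.weight = nn.Parameter(q * scale, requires_grad=False)
    return linear


# Usage: transform every Linear except the one whose fully qualified name is "2"
m = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
seen: List[str] = []


def skip_last_linear(module: nn.Module, fqn: str) -> bool:
    seen.append(fqn)
    return isinstance(module, nn.Linear) and fqn != "2"


toy_quantize_(m, fake_int8_weight_only, skip_last_linear)
```

The fully qualified name lets a filter exclude specific layers (for example an output head) even when they share a type with layers that should be quantized.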
Example
```python
import torch
import torch.nn as nn
from torchao import quantize_

# 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
# optimized execution paths or kernels (e.g. int4 tinygemm kernel),
# also customizable with arguments
# currently options are
# int8_dynamic_activation_int4_weight (for executorch)
# int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile)
# int4_weight_only (optimized with int4 tinygemm kernel and torch.compile)
# int8_weight_only (optimized with int8 mm op and torch.compile)
from torchao.quantization.quant_api import int4_weight_only

# the int4 tinygemm kernel expects bfloat16 weights on a CUDA device
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)).to(torch.bfloat16).cuda()
quantize_(m, int4_weight_only(group_size=32))

# 2. write your own new apply_tensor_subclass
# You can also add your own apply_tensor_subclass by manually calling the tensor subclass
# constructor on the weight
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

# weight only uint4 asymmetric groupwise quantization
groupsize = 32
apply_weight_quant = lambda x: to_affine_quantized_intx(
    x, MappingType.ASYMMETRIC, (1, groupsize), torch.int32, 0, 15, 1e-6,
    zero_point_dtype=torch.bfloat16, preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT)

def apply_weight_quant_to_linear(linear):
    # replace the float weight with the quantized tensor subclass (frozen)
    linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
    return linear

# apply to every nn.Linear module in the model
def filter_fn(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear)

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, apply_weight_quant_to_linear, filter_fn)
```
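The custom transform in part 2 asks for asymmetric uint4 groupwise quantization: one group per (1, groupsize) block of the weight, integer range 0 to 15, eps 1e-6, a float zero point and preserve_zero=False. The sketch below approximates that arithmetic in plain PyTorch so the effect of those arguments is visible. It is a simplified illustration, not torchao's code: uint4_asym_groupwise and dequantize are hypothetical names, and taking the group minimum as the float zero point is an assumption (torchao's exact zero-point convention may differ).

```python
import torch


def uint4_asym_groupwise(w: torch.Tensor, group_size: int = 32, eps: float = 1e-6):
    """Simplified asymmetric uint4 quantization with a float zero point, one group per
    (1, group_size) block of a 2-D weight. Illustrative; not torchao's exact formulas."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0, "in_features must be divisible by group_size"
    groups = w.reshape(out_features, in_features // group_size, group_size)
    g_min = groups.amin(dim=-1, keepdim=True)
    g_max = groups.amax(dim=-1, keepdim=True)
    # 16 levels (quant_min=0, quant_max=15); eps guards against constant groups
    scale = ((g_max - g_min) / 15).clamp(min=eps)
    q = torch.clamp(torch.round((groups - g_min) / scale), 0, 15).to(torch.uint8)
    # Assumption: the zero point stays in the float domain (here the group minimum),
    # so 0.0 itself need not be exactly representable (cf. preserve_zero=False)
    return q, scale, g_min


def dequantize(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    # Map each uint4 code back to float and restore the original 2-D shape
    return (q.to(scale.dtype) * scale + zero).reshape(q.shape[0], -1)


w = torch.randn(64, 128)
q, scale, zero = uint4_asym_groupwise(w, group_size=32)
w_hat = dequantize(q, scale, zero)
```

Because every group gets its own scale and offset, the reconstruction error of each element is bounded by half of its group's scale, which is why smaller group sizes usually trade extra metadata for accuracy.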