torchrec.quant¶
Torchrec Quantization
Torchrec provides a quantized version of EmbeddingBagCollection for inference. It relies on fbgemm quantized ops. This reduces the size of model weights and speeds up model execution.
Example
>>> import torch
>>> import torch.quantization as quant
>>> import torchrec.quant as trec_quant
>>> import torchrec as trec
>>> qconfig = quant.QConfig(
>>>     activation=quant.PlaceholderObserver,
>>>     weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
>>> )
>>> quantized = quant.quantize_dynamic(
>>>     module,
>>>     qconfig_spec={
>>>         trec.EmbeddingBagCollection: qconfig,
>>>     },
>>>     mapping={
>>>         trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
>>>     },
>>>     inplace=inplace,
>>> )
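The snippet above leaves module and inplace as placeholders. A minimal end-to-end sketch, assuming a single illustrative table and feature name (the table name, sizes, and input values below are assumptions, not part of the original example):

import torch
import torch.quantization as quant
import torchrec as trec
import torchrec.quant as trec_quant
from torchrec import EmbeddingBagConfig, KeyedJaggedTensor

# Illustrative table; name, sizes, and feature names are arbitrary.
table = EmbeddingBagConfig(
    name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
)
module = trec.EmbeddingBagCollection(tables=[table], device=torch.device("cpu"))

qconfig = quant.QConfig(
    activation=quant.PlaceholderObserver,
    weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
)
quantized = quant.quantize_dynamic(
    module,
    qconfig_spec={trec.EmbeddingBagCollection: qconfig},
    mapping={trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection},
    inplace=False,
)

# Batch of three samples for feature "f1": [0, 1], [], and [2].
features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    offsets=torch.tensor([0, 2, 2, 3]),
)
pooled = quantized(features)  # KeyedTensor of pooled, quantized embeddings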
torchrec.quant.embedding_modules¶
- class torchrec.quant.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)¶
Bases: EmbeddingBagCollectionInterface, ModuleNoCopyMixin
EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags). This EmbeddingBagCollection is quantized for lower precision. It relies on fbgemm quantized ops and provides table batching.
Note
EmbeddingBagCollection is an unsharded module and is not performance-optimized. For performance-sensitive scenarios, consider the sharded version, ShardedEmbeddingBagCollection.
It processes sparse data in the form of a KeyedJaggedTensor with values of the form [F X B X L], where:
F: features (keys)
B: batch size
L: length of sparse features (jagged)
and outputs a KeyedTensor with values of the form [B * (F * D)], where:
F: features (keys)
D: each feature's (key's) embedding dimension
B: batch size
- Parameters:
table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]]) – map of tables to quantized weights
embedding_configs (List[EmbeddingBagConfig]) – list of embedding tables
is_weighted (bool) – whether the input KeyedJaggedTensor is weighted
device (Optional[torch.device]) – default compute device
- Call Args:
features: KeyedJaggedTensor
- Returns:
KeyedTensor
Example
table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1]   None     [2]
# "f2"   [3]     [4]      [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

ebc.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(
        dtype=torch.qint8
    ),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)

qebc = QuantEmbeddingBagCollection.from_float(ebc)
quantized_embeddings = qebc(features)
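Following the [B * (F * D)] convention above, the resulting KeyedTensor here holds a [3, 7] values tensor (B=3, D=3 for "f1" plus D=4 for "f2"). A small sketch of inspecting it (the exact embedding values depend on the quantized weights):

kt = quantized_embeddings
print(kt.keys())            # ['f1', 'f2']
print(kt.values().shape)    # torch.Size([3, 7]) -> B=3, D(f1)=3 + D(f2)=4
per_feature = kt.to_dict()  # {'f1': tensor of shape [3, 3], 'f2': tensor of shape [3, 4]}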
- property device: device¶
- embedding_bag_configs() List[EmbeddingBagConfig] ¶
- forward(features: KeyedJaggedTensor) KeyedTensor ¶
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
KeyedTensor
- classmethod from_float(module: EmbeddingBagCollection, use_precomputed_fake_quant: bool = False) EmbeddingBagCollection ¶
- is_weighted() bool ¶
- output_dtype() dtype ¶
- training: bool¶
- class torchrec.quant.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: device, need_indices: bool = False, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16)¶
Bases: EmbeddingCollectionInterface, ModuleNoCopyMixin
EmbeddingCollection represents a collection of non-pooled embeddings.
Note
EmbeddingCollection is an unsharded module and is not performance-optimized. For performance-sensitive scenarios, consider the sharded version, ShardedEmbeddingCollection.
It processes sparse data in the form of a KeyedJaggedTensor of the form [F X B X L], where:
F: features (keys)
B: batch size
L: length of sparse features (variable)
and outputs Dict[feature (key), JaggedTensor]. Each JaggedTensor contains values of the form (B * L) X D, where:
B: batch size
L: length of sparse features (jagged)
D: each feature's (key's) embedding dimension, with lengths of the form L
- Parameters:
tables (List[EmbeddingConfig]) – list of embedding tables.
device (Optional[torch.device]) – default compute device.
need_indices (bool) – whether indices need to be passed to the final lookup result dict
Example
e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)
ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1]   None     [2]
# 1   [3]     [4]      [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings["f2"].values())
tensor([[-0.2050,  0.5478,  0.6054],
        [ 0.7352,  0.3210, -3.0399],
        [ 0.1279, -0.1756, -0.4130],
        [ 0.7519, -0.4341, -0.0499],
        [ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
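The example above exercises the float EmbeddingCollection. A hedged sketch of converting it to this quantized class with from_float, mirroring the EmbeddingBagCollection example (attaching a qconfig the same way is an assumption carried over from that example):

import torch
from torchrec.quant.embedding_modules import EmbeddingCollection as QuantEmbeddingCollection

# Assumed to mirror the EmbeddingBagCollection flow: attach a qconfig, then convert.
ec.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qec = QuantEmbeddingCollection.from_float(ec)
quantized_embeddings = qec(features)  # Dict[str, JaggedTensor]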
- property device: device¶
- embedding_configs() List[EmbeddingConfig] ¶
- embedding_dim() int ¶
- embedding_names_by_table() List[List[str]] ¶
- forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor] ¶
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
Dict[str, JaggedTensor]
- classmethod from_float(module: EmbeddingCollection, use_precomputed_fake_quant: bool = False) EmbeddingCollection ¶
- need_indices() bool ¶
- output_dtype() dtype ¶
- training: bool¶
- class torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool, device: device, output_dtype: dtype = torch.float32, table_name_to_quantized_weights: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, register_tbes: bool = False, quant_state_dict_split_scale_bias: bool = False, row_alignment: int = 16, feature_processor: Optional[FeatureProcessorsCollection] = None)¶
Bases: EmbeddingBagCollection
- embedding_bags: nn.ModuleDict¶
- forward(features: KeyedJaggedTensor) KeyedTensor ¶
- Parameters:
features (KeyedJaggedTensor) – KJT of form [F X B X L].
- Returns:
KeyedTensor
- classmethod from_float(module: FeatureProcessedEmbeddingBagCollection, use_precomputed_fake_quant: bool = False) FeatureProcessedEmbeddingBagCollection ¶
- tbes: torch.nn.ModuleList¶
- training: bool¶
- torchrec.quant.embedding_modules.for_each_module_of_type_do(module: Module, module_types: List[Type[Module]], op: Callable[[Module], None]) None ¶
- torchrec.quant.embedding_modules.pruned_num_embeddings(pruning_indices_mapping: Tensor) int ¶
- torchrec.quant.embedding_modules.quant_prep_customize_row_alignment(module: Module, module_types: List[Type[Module]], row_alignment: int) None ¶
- torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias(module: Module) None ¶
- torchrec.quant.embedding_modules.quant_prep_enable_quant_state_dict_split_scale_bias_for_types(module: Module, module_types: List[Type[Module]]) None ¶
- torchrec.quant.embedding_modules.quant_prep_enable_register_tbes(module: Module, module_types: List[Type[Module]]) None ¶
- torchrec.quant.embedding_modules.quantize_state_dict(module: Module, table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]], table_name_to_data_type: Dict[str, DataType], table_name_to_pruning_indices_mapping: Optional[Dict[str, Tensor]] = None) device ¶
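The quant_prep_* helpers above take a module tree and the module types to prepare, per their signatures. A minimal sketch of calling them on a model before conversion (the assumption that they simply flag the matched submodules for the subsequent from_float / quantize_dynamic pass is mine; model is a placeholder):

import torchrec as trec
from torchrec.quant.embedding_modules import (
    quant_prep_customize_row_alignment,
    quant_prep_enable_quant_state_dict_split_scale_bias_for_types,
    quant_prep_enable_register_tbes,
)

# `model` is assumed to contain one or more trec.EmbeddingBagCollection submodules.
quant_prep_enable_register_tbes(model, [trec.EmbeddingBagCollection])
quant_prep_enable_quant_state_dict_split_scale_bias_for_types(
    model, [trec.EmbeddingBagCollection]
)
quant_prep_customize_row_alignment(
    model, [trec.EmbeddingBagCollection], row_alignment=16
)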