表格批量嵌入 (TBE) 推理模块¶
稳定 API¶
- class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[source]¶
nn.EmbeddingBag(sparse=False) 的表格批量版本,推理版本,支持 FP32/FP16/FP8/INT8/INT4/INT2 权重
- 参数:
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –
嵌入规格列表。每个规格描述一个物理嵌入表的规范。每个规格是一个元组,包含嵌入行数、嵌入维度(必须是 4 的倍数)、表格放置位置 (EmbeddingLocation) 和计算设备 (ComputeDevice)。
可用的 EmbeddingLocation 选项包括:
DEVICE = 将嵌入表放置在 GPU 全局内存 (HBM) 中
MANAGED = 将嵌入放置在统一虚拟内存中(GPU 和 CPU 均可访问)
MANAGED_CACHING = 将嵌入表放置在统一虚拟内存中,并使用 GPU 全局内存 (HBM) 作为缓存
HOST = 将嵌入表放置在 CPU 内存 (DRAM) 中
MTIA = 将嵌入表放置在 MTIA 内存中
可用的 ComputeDevice 选项包括:
CPU = 在 CPU 上执行表格查找
CUDA = 在 GPU 上执行表格查找
MTIA = 在 MTIA 上执行表格查找
feature_table_map (Optional[List[int]] = None) – 一个可选列表,用于指定特征到表格的映射。feature_table_map[i] 指示特征 i 映射到的物理嵌入表。
index_remapping (Optional[List[Tensor]] = None) – 用于剪枝的索引重映射
pooling_mode (PoolingMode = PoolingMode.SUM) –
池化模式。可用的 PoolingMode 选项包括:
SUM = 求和池化
MEAN = 平均池化
NONE = 无池化(序列嵌入)
device (Optional[Union[str, int, torch.device]] = None) – 当前放置张量的设备
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –
输入检查模式。可用的 BoundsCheckMode 选项包括:
NONE = 跳过边界检查
FATAL = 当遇到无效索引/偏移量时抛出错误
WARNING = 当遇到无效索引/偏移量时打印警告消息并修复它(将无效索引设置为零,并将无效偏移量调整到界限内)
IGNORE = 静默修复无效索引/偏移量(将无效索引设置为零,并将无效偏移量调整到界限内)
weight_lists (可选[List[Tuple[Tensor, Optional[Tensor]]]] = None) – [T]
pruning_hash_load_factor (float = 0.5) – 剪枝哈希的加载因子
use_array_for_index_remapping (bool = True) – 如果为 True,则使用数组进行索引重映射。否则,使用哈希映射。
output_dtype (SparseType = SparseType.FP16) – 输出张量的数据类型。
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –
缓存算法(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。选项包括
LRU = 最近最少使用
LFU = 最不常用
cache_load_factor (float = 0.2) – 当使用 EmbeddingLocation.MANAGED_CACHING 时,用于确定缓存容量的因子。缓存容量为 cache_load_factor * 所有嵌入表中的总行数
cache_sets (int = 0) – 缓存集的数量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)
cache_reserved_memory (float = 0.0) – 在 HBM 中为非缓存目的保留的内存量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。
enforce_hbm (bool = False) – 如果为 True,则在使用 EmbeddingLocation.MANAGED_CACHING 时将所有权重/动量放置在 HBM 中
record_cache_metrics (Optional[RecordCacheMetrics] = None) – 如果 RecordCacheMetrics.record_cache_miss_counter 为 True,则记录命中次数、请求次数等;如果 RecordCacheMetrics.record_tablewise_cache_miss is True,则按表记录类似的指标
gather_uvm_cache_stats (Optional[bool] = False) – 如果为 True,则当 EmbeddingLocation 设置为 MANAGED_CACHING 时收集缓存统计信息
row_alignment (Optional[int] = None) – 行对齐
fp8_exponent_bits (Optional[int] = None) – 使用 FP8 时的指数位
fp8_exponent_bias (Optional[int] = None) – 使用 FP8 时的指数偏差
cache_assoc (int = 32) – 缓存的路数
scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – 缩放和偏差的大小(字节)
cacheline_alignment (bool = True) – 如果为 True,则将每个表对齐到 128b 缓存行边界
uvm_host_mapped (bool = False) – 如果为 True,则使用 malloc + cudaHostRegister 分配每个 UVM 张量。否则使用 cudaMallocManaged
reverse_qparam (bool = False) – 如果为 True,则在每行末尾加载 qparams。否则,在每行开头加载 qparams。
feature_names_per_table (Optional[List[List[str]]] = None) – 一个可选列表,指定每个表的特征名称。feature_names_per_table[t] 指示表 t 的特征名称。
indices_dtype (torch.dtype = torch.int32) – 将传递给 forward() 调用的索引张量的预期 dtype。此信息将用于构建 remap_indices 数组/哈希。选项包括 torch.int32 和 torch.int64。
- assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) None [source]¶
使用来自输入权重和 scale_shifts 列表的值分配 self.split_embedding_weights()。
- forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor [source]¶
定义每次调用时执行的计算。
应由所有子类重写。
注意
虽然前向传递的配方需要在该函数内定义,但应该在此之后调用
Module
实例而不是此函数,因为前者负责运行注册的钩子,而后者会静默地忽略它们。
- recompute_module_buffers() None [source]¶
计算元设备上的模块缓冲区,这些缓冲区在 reset_weights_placements_and_offsets() 中未实现。目前这些缓冲区是 weights_tys、rows_per_table、D_offsets 和 bounds_check_warning。剪枝相关的或 uvm 相关的缓冲区目前未计算。