• 文档 >
  • 表格批量嵌入 (TBE) 推理模块
快捷方式

表格批量嵌入 (TBE) 推理模块

稳定 API

class fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]], feature_table_map: List[int] | None = None, index_remapping: List[Tensor] | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, weight_lists: List[Tuple[Tensor, Tensor | None]] | None = None, pruning_hash_load_factor: float = 0.5, use_array_for_index_remapping: bool = True, output_dtype: SparseType = SparseType.FP16, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, enforce_hbm: bool = False, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, row_alignment: int | None = None, fp8_exponent_bits: int | None = None, fp8_exponent_bias: int | None = None, cache_assoc: int = 32, scale_bias_size_in_bytes: int = 4, cacheline_alignment: bool = True, uvm_host_mapped: bool = False, reverse_qparam: bool = False, feature_names_per_table: List[List[str]] | None = None, indices_dtype: dtype = torch.int32)[source]

nn.EmbeddingBag(sparse=False) 的表格批量版本,推理版本,支持 FP32/FP16/FP8/INT8/INT4/INT2 权重

参数:
  • embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –

    嵌入规格列表。每个规格描述一个物理嵌入表的规范。每个规格是一个元组,包含嵌入行数、嵌入维度(必须是 4 的倍数)、表格放置位置 (EmbeddingLocation) 和计算设备 (ComputeDevice)。

    可用的 EmbeddingLocation 选项包括:

    1. DEVICE = 将嵌入表放置在 GPU 全局内存 (HBM) 中

    2. MANAGED = 将嵌入放置在统一虚拟内存中(GPU 和 CPU 均可访问)

    3. MANAGED_CACHING = 将嵌入表放置在统一虚拟内存中,并使用 GPU 全局内存 (HBM) 作为缓存

    4. HOST = 将嵌入表放置在 CPU 内存 (DRAM) 中

    5. MTIA = 将嵌入表放置在 MTIA 内存中

    可用的 ComputeDevice 选项包括:

    1. CPU = 在 CPU 上执行表格查找

    2. CUDA = 在 GPU 上执行表格查找

    3. MTIA = 在 MTIA 上执行表格查找

  • feature_table_map (Optional[List[int]] = None) – 一个可选列表,用于指定特征到表格的映射。feature_table_map[i] 指示特征 i 映射到的物理嵌入表。

  • index_remapping (Optional[List[Tensor]] = None) – 用于剪枝的索引重映射

  • pooling_mode (PoolingMode = PoolingMode.SUM) –

    池化模式。可用的 PoolingMode 选项包括:

    1. SUM = 求和池化

    2. MEAN = 平均池化

    3. NONE = 无池化(序列嵌入)

  • device (Optional[Union[str, int, torch.device]] = None) – 当前放置张量的设备

  • bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –

    输入检查模式。可用的 BoundsCheckMode 选项包括:

    1. NONE = 跳过边界检查

    2. FATAL = 当遇到无效索引/偏移量时抛出错误

    3. WARNING = 当遇到无效索引/偏移量时打印警告消息并修复它(将无效索引设置为零,并将无效偏移量调整到界限内)

    4. IGNORE = 静默修复无效索引/偏移量(将无效索引设置为零,并将无效偏移量调整到界限内)

  • weight_lists (可选[List[Tuple[Tensor, Optional[Tensor]]]] = None) – [T]

  • pruning_hash_load_factor (float = 0.5) – 剪枝哈希的加载因子

  • use_array_for_index_remapping (bool = True) – 如果为 True,则使用数组进行索引重映射。否则,使用哈希映射。

  • output_dtype (SparseType = SparseType.FP16) – 输出张量的数据类型。

  • cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –

    缓存算法(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。选项包括

    1. LRU = 最近最少使用

    2. LFU = 最不常用

  • cache_load_factor (float = 0.2) – 当使用 EmbeddingLocation.MANAGED_CACHING 时,用于确定缓存容量的因子。缓存容量为 cache_load_factor * 所有嵌入表中的总行数

  • cache_sets (int = 0) – 缓存集的数量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)

  • cache_reserved_memory (float = 0.0) – 在 HBM 中为非缓存目的保留的内存量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。

  • enforce_hbm (bool = False) – 如果为 True,则在使用 EmbeddingLocation.MANAGED_CACHING 时将所有权重/动量放置在 HBM 中

  • record_cache_metrics (Optional[RecordCacheMetrics] = None) – 如果 RecordCacheMetrics.record_cache_miss_counter 为 True,则记录命中次数、请求次数等;如果 RecordCacheMetrics.record_tablewise_cache_miss is True,则按表记录类似的指标

  • gather_uvm_cache_stats (Optional[bool] = False) – 如果为 True,则当 EmbeddingLocation 设置为 MANAGED_CACHING 时收集缓存统计信息

  • row_alignment (Optional[int] = None) – 行对齐

  • fp8_exponent_bits (Optional[int] = None) – 使用 FP8 时的指数位

  • fp8_exponent_bias (Optional[int] = None) – 使用 FP8 时的指数偏差

  • cache_assoc (int = 32) – 缓存的路数

  • scale_bias_size_in_bytes (int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES) – 缩放和偏差的大小(字节)

  • cacheline_alignment (bool = True) – 如果为 True,则将每个表对齐到 128b 缓存行边界

  • uvm_host_mapped (bool = False) – 如果为 True,则使用 malloc + cudaHostRegister 分配每个 UVM 张量。否则使用 cudaMallocManaged

  • reverse_qparam (bool = False) – 如果为 True,则在每行末尾加载 qparams。否则,在每行开头加载 qparams

  • feature_names_per_table (Optional[List[List[str]]] = None) – 一个可选列表,指定每个表的特征名称。feature_names_per_table[t] 指示表 t 的特征名称。

  • indices_dtype (torch.dtype = torch.int32) – 将传递给 forward() 调用的索引张量的预期 dtype。此信息将用于构建 remap_indices 数组/哈希。选项包括 torch.int32torch.int64

assign_embedding_weights(q_weight_list: List[Tuple[Tensor, Tensor | None]]) None[source]

使用来自输入权重和 scale_shifts 列表的值分配 self.split_embedding_weights()。

fill_random_weights() None[source]

逐表用随机权重填充缓冲区

forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None) Tensor[source]

定义每次调用时执行的计算。

应由所有子类重写。

注意

虽然前向传递的配方需要在该函数内定义,但应该在此之后调用 Module 实例而不是此函数,因为前者负责运行注册的钩子,而后者会静默地忽略它们。

recompute_module_buffers() None[source]

计算元设备上的模块缓冲区,这些缓冲区在 reset_weights_placements_and_offsets() 中未实现。目前这些缓冲区是 weights_tysrows_per_tableD_offsetsbounds_check_warning。剪枝相关的或 uvm 相关的缓冲区目前未计算。

split_embedding_weights(split_scale_shifts: bool = True) List[Tuple[Tensor, Tensor | None]][source]

返回按表拆分的权重列表

split_embedding_weights_with_scale_bias(split_scale_bias_mode: int = 1) List[Tuple[Tensor, Tensor | None, Tensor | None]][source]

返回按表拆分的权重列表 split_scale_bias_mode

0:返回一行;1:返回权重 + scale_bias;2:返回权重、scale、bias。

其他 API

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题的解答

查看资源