• 文档 >
  • 表格批量嵌入 (TBE) 训练模块
快捷方式

表格批量嵌入 (TBE) 训练模块

稳定 API

class fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]], feature_table_map: List[int] | None = None, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, cache_precision: SparseType = SparseType.FP32, weights_precision: SparseType = SparseType.FP32, output_dtype: SparseType = SparseType.FP32, enforce_hbm: bool = False, optimizer: EmbOptimType = EmbOptimType.EXACT_SGD, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, stochastic_rounding: bool = True, gradient_clipping: bool = False, max_gradient: float = 1.0, max_norm: float = 0.0, learning_rate: float = 0.01, eps: float = 1e-08, momentum: float = 0.9, weight_decay: float = 0.0, weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, eta: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, ensemble_mode: EnsembleModeDefinition | None = None, counter_based_regularization: CounterBasedRegularizationDefinition | None = None, cowclip_regularization: CowClipDefinition | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, uvm_non_rowwise_momentum: bool = False, use_experimental_tbe: bool = False, prefetch_pipeline: bool = False, stats_reporter_config: TBEStatsReporterConfig | None = None, table_names: List[str] | None = None, optimizer_state_dtypes: Dict[str, SparseType] | None = None, multipass_prefetch_config: MultiPassPrefetchConfig | None = None, global_weight_decay: GlobalWeightDecayDefinition | None = None, uvm_host_mapped: bool = False)[source]

表格批量嵌入 (TBE) 运算符。查找一个或多个嵌入表。该模块适用于训练。反向运算符与优化器融合。因此,嵌入表在反向传播期间更新。

参数:
  • embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –

    嵌入规范列表。每个规范描述物理嵌入表的规范。每个规范都是一个元组,包含嵌入行数、嵌入维度(必须是 4 的倍数)、表位置(EmbeddingLocation)和计算设备(ComputeDevice)。

    可用的 EmbeddingLocation 选项是

    1. DEVICE = 将嵌入表放置在 GPU 全局内存 (HBM) 中

    2. MANAGED = 将嵌入放置在统一虚拟内存中(GPU 和 CPU 都可以访问)

    3. MANAGED_CACHING = 将嵌入表放置在统一虚拟内存中,并使用 GPU 全局内存 (HBM) 作为缓存

    4. HOST = 将嵌入表放置在 CPU 内存 (DRAM) 中

    5. MTIA = 将嵌入表放置在 MTIA 内存中

    可用的 ComputeDevice 选项是

    1. CPU = 在 CPU 上执行表查找

    2. CUDA = 在 GPU 上执行表查找

    3. MTIA = 在 MTIA 上执行表查找

  • feature_table_map (Optional[List[int]] = None) – 一个可选列表,指定特征-表映射。feature_table_map[i] 指示特征 i 映射到的物理嵌入表。

  • cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –

    缓存算法(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。选项是

    1. LRU = 最近最少使用

    2. LFU = 最不常用

  • cache_load_factor (float = 0.2) – 用于确定缓存容量的因子(当使用 EmbeddingLocation.MANAGED_CACHING 时)。缓存容量为 cache_load_factor * 所有嵌入表中行的总数

  • cache_sets (int = 0) – 缓存集的数量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)

  • cache_reserved_memory (float = 0.0) – 为非缓存目的保留在 HBM 中的内存量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。

  • cache_precision (SparseType = SparseType.FP32) – 缓存的数据类型(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。选项为 SparseType.FP32SparseType.FP16

  • weights_precision (SparseType = SparseType.FP32) – 嵌入表(也称为权重)的数据类型。选项为 SparseType.FP32SparseType.FP16

  • output_dtype (SparseType = SparseType.FP32) – 输出张量的数据类型。选项为 SparseType.FP32SparseType.FP16

  • enforce_hbm (bool = False) – 如果为 True,则在使用 EmbeddingLocation.MANAGED_CACHING 时将所有权重/动量放置在 HBM 中

  • optimizer (OptimType = OptimType.EXACT_SGD) –

    在反向传播中用于嵌入表更新的优化器。可用的 OptimType 选项是

    1. ADAM = Adam

    2. EXACT_ADAGRAD = Adagrad

    3. EXACT_ROWWISE_ADAGRAD = 行级 Adagrad

    4. EXACT_SGD = SGD

    5. LAMB = Lamb

    6. LARS_SGD = LARS-SGD

    7. PARTIAL_ROWWISE_ADAM = 部分行级 Adam

    8. PARTIAL_ROWWISE_LAMB = 部分行级 Lamb

    9. ENSEMBLE_ROWWISE_ADAGRAD = 集成行级 Adagrad

    10. NONE = 在反向传播中不应用优化器更新

    并输出稀疏权重梯度

  • record_cache_metrics (Optional[RecordCacheMetrics] = None) – 如果 RecordCacheMetrics.record_cache_miss_counter 为 True,则记录命中次数、请求次数等,如果 RecordCacheMetrics.record_tablewise_cache_miss 为 True,则按表记录类似的指标

  • gather_uvm_cache_stats (Optional[bool] = False) – 如果为 True,则在 EmbeddingLocation 设置为 MANAGED_CACHING 时收集缓存统计信息

  • stochastic_rounding (bool = True) – 如果为 True,则对不是 SparseType.FP32 的权重类型应用随机舍入

  • gradient_clipping (bool = False) – 如果为 True,则应用梯度裁剪

  • max_gradient (float = 1.0) – 梯度裁剪的值

  • max_norm (float = 0.0) – 最大范数值

  • learning_rate (float = 0.01) – 学习率

  • eps (float = 1.0e-8) – Adagrad、LAMB 和 Adam 使用的 epsilon 值。请注意,默认值与 torch.nn.optim.Adagrad 的默认值 1e-10 不同

  • momentum (float = 0.9) – LARS-SGD 使用的动量

  • weight_decay (float = 0.0) –

    LARS-SGD、LAMB、ADAM 和行级 Adagrad 使用的权重衰减。

    1. EXACT_ADAGRAD、SGD、EXACT_SGD 不支持权重衰减

    2. LAMB、ADAM、PARTIAL_ROWWISE_ADAM、PARTIAL_ROWWISE_LAMB、LARS_SGD 支持解耦权重衰减

    3. EXACT_ROWWISE_ADAGRAD 支持 L2 和解耦权重衰减(通过 weight_decay_mode)

  • weight_decay_mode (WeightDecayMode = WeightDecayMode.NONE) – 权重衰减模式。选项为 WeightDecayMode.NONEWeightDecayMode.L2WeightDecayMode.DECOUPLE

  • eta (float = 0.001) – LARS-SGD 使用的 eta 值

  • beta1 (float = 0.9) – LAMB 和 ADAM 使用的 beta1 值

  • beta2 (float = 0.999) – LAMB 和 ADAM 使用的 beta2 值

  • ensemble_mode (Optional[EnsembleModeDefinition] = None) – 集成行级 Adagrad 使用

  • counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None) – 行级 Adagrad 使用

  • cowclip_regularization (Optional[CowClipDefinition] = None) – 行级 Adagrad 使用

  • pooling_mode (PoolingMode = PoolingMode.SUM) –

    池化模式。可用的 PoolingMode 选项是

    1. SUM = 求和池化

    2. MEAN = 平均池化

    3. NONE = 无池化(序列嵌入)

  • device (Optional[str, int, torch.device]] = None) – 要将张量放置在其上的当前设备

  • bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –

    输入检查模式。可用的 BoundsCheckMode 选项是

    1. NONE = 跳过边界检查

    2. FATAL = 遇到无效索引/偏移量时抛出错误

    3. WARNING = 遇到无效索引/偏移量时打印警告消息并修复它(将无效索引设置为零并调整无效偏移量使其在范围内)

    4. IGNORE = 静默修复无效索引/偏移量(将无效索引设置为零并调整无效偏移量使其在范围内)

  • uvm_non_rowwise_momentum (bool = False) – 如果为 True,则在统一虚拟内存上放置非行主导动量

  • use_experimental_tbe (bool = False) – 如果为 True,则使用优化的 TBE 实现(TBE v2)。请注意,这仅在 NVIDIA GPU 上受支持。

  • prefetch_pipeline (bool = False) – 如果为 True,则在使用 EmbeddingLocation.MANAGED_CACHING 时启用缓存预取流水线。目前仅支持 LRU 缓存策略。如果使用单独的流进行预取,则必须设置预取函数的可选 forward_stream 参数。

  • stats_reporter_config (Optional[TBEStatsReporterConfig] = None) – TBE 统计报告器的配置

  • table_names (Optional[List[str]] = None) – 此 TBE 中嵌入表名称的列表

  • optimizer_state_dtypes (Optional[Dict[str, SparseType]] = None) – 优化器状态数据类型字典。键是优化器状态名称,值是其对应的类型

  • multipass_prefetch_config (Optional[MultiPassPrefetchConfig] = None) – 多遍缓存预取的配置(当使用 EmbeddingLocation.MANAGED_CACHING 时)

  • global_weight_decay (Optional[GlobalWeightDecayDefinition] = None) – 全局权重衰减的配置

  • uvm_host_mapped (bool = False) – 如果为 True,则使用 malloc + cudaHostRegister 分配每个 UVM 张量。否则使用 cudaMallocManaged

forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None, feature_requires_grad: Tensor | None = None, batch_size_per_feature_per_rank: List[List[int]] | None = None, total_unique_indices: int | None = None) Tensor[source]

前向传递函数,

  1. 执行输入边界检查

  2. 生成必要的可变批量大小嵌入 (VBE) 元数据(如果使用 VBE)

  3. 从 UVM 预取数据到缓存(如果使用 EmbeddingLocation.MANAGED_CACHING 且用户未显式预取数据)

  4. 通过调用相应的 Autograd 函数(基于所选优化器)执行嵌入表查找

参数:
  • indices (Tensor) – 一个 1D 张量,包含要从所有嵌入表中查找的索引

  • offsets (Tensor) – 一个 1D 张量,包含索引的偏移量。形状 (B * T + 1),其中 B = 批次大小,T = 特征数量。offsets[t * B + b + 1] - offsets[t * B + b] 是特征 t 的包 b 的长度

  • per_sample_weights (Optional[Tensor]) – 一个可选的 1D 浮点张量,包含每个样本的权重。如果为 None,则将执行**无权重**嵌入查找。否则,将使用**加权**。此张量的长度必须与 indices 张量的长度相同。per_sample_weights[i] 的值将用于乘以查找的行 indices[i] 中的每个元素,其中 0 <= i < len(per_sample_weights)

  • feature_requires_grad (Optional[Tensor]) – 一个可选的 1D 张量,用于指示 per_sample_weights 是否需要梯度。张量的长度必须等于特征的数量

  • batch_size_per_feature_per_rank (Optional[List[List[int]]]) – 一个可选的 2D 张量,包含每个秩和每个特征的批次大小。如果为 None,则 TBE 假设**每个特征具有相同的批次大小**,并根据 offsets 形状计算批次大小。否则,TBE 假设不同的特征可以具有不同的批次大小,并使用**可变批次大小嵌入查找模式 (VBE)**。形状(特征数量,秩数量)。batch_size_per_feature_per_rank[f][r] 表示特征 f 和秩 r 的批次大小

  • total_unique_indices (Optional[int]) – 一个可选的整数,表示唯一索引的总数。在使用 OptimType.NONE 时必须设置此值。这是因为 TBE 需要此信息才能在反向传递中分配权重梯度张量。

返回值:

一个包含查找数据的 2D 张量。形状 (B, total_D),其中 B = 批次大小,total_D = 表中所有嵌入维度的总和

示例

>>> import torch
>>>
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
>>>    EmbeddingLocation,
>>> )
>>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
>>>    SplitTableBatchedEmbeddingBagsCodegen,
>>>    ComputeDevice,
>>> )
>>>
>>> # Two tables
>>> embedding_specs = [
>>>     (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
>>>     (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA)
>>> ]
>>>
>>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs)
>>> tbe.init_embedding_weights_uniform(-1, 1)
>>>
>>> print(tbe.split_embedding_weights())
[tensor([[-0.9426,  0.7046,  0.4214, -0.0419,  0.1331, -0.7856, -0.8124, -0.2021],
        [-0.5771,  0.5911, -0.7792, -0.1068, -0.6203,  0.4813, -0.1677,  0.4790],
        [-0.5587, -0.0941,  0.5754,  0.3475, -0.8952, -0.1964,  0.0810, -0.4174]],
       device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775,  0.3273],
        [-0.5399, -0.0229, -0.1455, -0.8770],
        [-0.9520,  0.4593, -0.7169,  0.6307],
        [-0.1765,  0.8757,  0.8614,  0.2051],
        [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]
>>> # Batch size = 3
>>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0],
>>>                        device="cuda",
>>>                        dtype=torch.long)
>>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13],
>>>                        device="cuda",
>>>                        dtype=torch.long)
>>>
>>> output = tbe(indices, offsets)
>>>
>>> # Batch size = 3, total embedding dimension = 12
>>> print(output.shape)
torch.Size([3, 12])
>>> print(output)
tensor([[-1.5197,  1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801,  0.2769,
         -0.7164,  0.8528,  0.7159, -0.6719],
        [-2.0784,  1.2016,  0.2176,  0.1988, -1.3825, -0.5008, -0.8991, -0.1405,
         -1.2637, -0.9427, -1.8902,  0.3754],
        [-1.5013,  0.6105,  0.9968,  0.3057, -0.7621, -0.9821, -0.7314, -0.6195,
         -0.2513, -0.4039, -0.3775,  0.3273]], device='cuda:0',
       grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)
set_learning_rate(lr: float) None[source]

设置学习率。

参数:

lr (float) – 要设置的学习率值

set_optimizer_step(step: int) None[source]

设置优化器步长。

参数:

step (int) – 要设置的步长值

split_embedding_weights() List[Tensor][source]

返回按表拆分的嵌入权重(视图)列表

返回值:

权重列表。长度 = 表格数量

split_optimizer_states() List[List[Tensor]][source]

返回按表拆分的优化器状态(视图)列表

返回值:

状态列表列表。形状 =(表数量,状态数量)。

以下显示了每个优化器的状态列表(按返回顺序)

  1. ADAM: momentum1, momentum2

  2. EXACT_ADAGRAD: momentum1

  3. EXACT_ROWWISE_ADAGRAD: momentum1(行方向),prev_iter(行方向;仅当使用 WeightDecayMode = COUNTERCOWCLIPglobal_weight_decay 不为 None 时),row_counter(行方向;仅当使用 WeightDecayMode = COUNTERCOWCLIP 时)

  4. EXACT_SGD: 无状态

  5. LAMB: momentum1momentum2

  6. LARS_SGD: momentum1

  7. PARTIAL_ROWWISE_ADAM: momentum1momentum2(行方向)

  8. PARTIAL_ROWWISE_LAMB: momentum1momentum2(行方向)

  9. ENSEMBLE_ROWWISE_ADAGRAD: momentum1(行方向),momentum2

  10. NONE: 无状态(抛出错误)

update_hyper_parameters(params_dict: Dict[str, float]) None[source]

从外部控制流设置超参数。

参数:

params_dict (Dict[str, float]) – 包含超参数名称及其值的字典

其他 API

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源