Table Batched Embedding (TBE) 训练模块¶
稳定 API¶
- class fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen(embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]], feature_table_map: List[int] | None = None, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, cache_load_factor: float = 0.2, cache_sets: int = 0, cache_reserved_memory: float = 0.0, cache_precision: SparseType | None = None, weights_precision: SparseType = SparseType.FP32, output_dtype: SparseType = SparseType.FP32, enforce_hbm: bool = False, optimizer: EmbOptimType = EmbOptimType.EXACT_SGD, record_cache_metrics: RecordCacheMetrics | None = None, gather_uvm_cache_stats: bool | None = False, stochastic_rounding: bool = True, gradient_clipping: bool = False, max_gradient: float = 1.0, max_norm: float = 0.0, learning_rate: float = 0.01, eps: float = 1e-08, momentum: float = 0.9, weight_decay: float = 0.0, weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, eta: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, ensemble_mode: EnsembleModeDefinition | None = None, emainplace_mode: EmainplaceModeDefinition | None = None, counter_based_regularization: CounterBasedRegularizationDefinition | None = None, cowclip_regularization: CowClipDefinition | None = None, pooling_mode: PoolingMode = PoolingMode.SUM, device: str | device | int | None = None, bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, uvm_non_rowwise_momentum: bool = False, use_experimental_tbe: bool = False, prefetch_pipeline: bool = False, stats_reporter_config: TBEStatsReporterConfig | None = None, table_names: List[str] | None = None, optimizer_state_dtypes: Dict[str, SparseType] | None = None, multipass_prefetch_config: MultiPassPrefetchConfig | None = None, global_weight_decay: GlobalWeightDecayDefinition | None = None, uvm_host_mapped: bool = False, extra_optimizer_config: UserEnabledConfigDefinition | None = None, tbe_input_multiplexer_config: TBEInputMultiplexerConfig | None = None, embedding_table_index_type: dtype = torch.int64, embedding_table_offset_type: dtype = torch.int64, embedding_shard_info: List[Tuple[int, int, int, int]] | None = None)[源代码]¶
Table Batched Embedding (TBE) 算子。查找一个或多个嵌入表。此模块应用于训练。反向算子与优化器融合。因此,嵌入表会在反向传播期间更新。
- 参数:
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) –
嵌入规范列表。每个规范描述了一个物理嵌入表的规格。每个规范是一个元组,包含嵌入行的数量、嵌入维度(必须是 4 的倍数)、表位置(EmbeddingLocation)和计算设备(ComputeDevice)。
可用的 EmbeddingLocation 选项包括
DEVICE = 将嵌入表放置在 GPU 全局内存 (HBM) 中
MANAGED = 将嵌入表放置在统一虚拟内存中(GPU 和 CPU 均可访问)
MANAGED_CACHING = 将嵌入表放置在统一虚拟内存中,并使用 GPU 全局内存 (HBM) 作为缓存
HOST = 将嵌入表放置在 CPU 内存 (DRAM) 中
MTIA = 将嵌入表放置在 MTIA 内存中
可用的 ComputeDevice 选项包括
CPU = 在 CPU 上执行表查找
CUDA = 在 GPU 上执行表查找
MTIA = 在 MTIA 上执行表查找
feature_table_map (Optional[List[int]] = None) – 可选列表,指定特征到表的映射。feature_table_map[i] 表示特征 i 映射到的物理嵌入表。
cache_algorithm (CacheAlgorithm = CacheAlgorithm.LRU) –
缓存算法(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。选项包括
LRU = 最近最少使用
LFU = 最不常用
cache_load_factor (float = 0.2) – 用于确定使用 EmbeddingLocation.MANAGED_CACHING 时的缓存容量的因子。缓存容量为 cache_load_factor * 所有嵌入表中的总行数。
cache_sets (int = 0) – 缓存集的数量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。
cache_reserved_memory (float = 0.0) – 在 HBM 中为非缓存目的保留的内存量(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。
cache_precision (SparseType = SparseType.FP32) – 缓存的数据类型(当 EmbeddingLocation 设置为 MANAGED_CACHING 时使用)。选项为 SparseType.FP32 和 SparseType.FP16。
weights_precision (SparseType = SparseType.FP32) – 嵌入表(也称为权重)的数据类型。选项为 SparseType.FP32 和 SparseType.FP16。
output_dtype (SparseType = SparseType.FP32) – 输出张量的数据类型。选项为 SparseType.FP32 和 SparseType.FP16。
enforce_hbm (bool = False) – 如果为 True,则在使用 EmbeddingLocation.MANAGED_CACHING 时将所有权重/动量放置在 HBM 中。
optimizer (OptimType = OptimType.EXACT_SGD) –
在反向传播中用于更新嵌入表的优化器。可用的 OptimType 选项包括
ADAM = Adam
EXACT_ADAGRAD = Adagrad
EXACT_ROWWISE_ADAGRAD = 按行 Adagrad
EXACT_SGD = SGD
LAMB = Lamb
LARS_SGD = LARS-SGD
PARTIAL_ROWWISE_ADAM = 部分按行 Adam
PARTIAL_ROWWISE_LAMB = 部分按行 Lamb
ENSEMBLE_ROWWISE_ADAGRAD = 集成按行 Adagrad
EMAINPLACE_ROWWISE_ADAGRAD = EMA 就地按行 Adagrad
NONE = 在反向传播中不应用优化器更新
并输出稀疏权重梯度
record_cache_metrics (Optional[RecordCacheMetrics] = None) – 记录命中次数、请求次数等,如果 RecordCacheMetrics.record_cache_miss_counter 为 True;如果 RecordCacheMetrics.record_tablewise_cache_miss 为 True,则按表记录类似指标。
gather_uvm_cache_stats (Optional[bool] = False) – 如果为 True,则当 EmbeddingLocation 设置为 MANAGED_CACHING 时收集缓存统计信息。
stochastic_rounding (bool = True) – 如果为 True,则对非 SparseType.FP32 的权重类型应用随机舍入。
gradient_clipping (bool = False) – 如果为 True,则应用梯度裁剪。
max_gradient (float = 1.0) – 梯度裁剪的值。
max_norm (float = 0.0) – 最大范数值。
learning_rate (float = 0.01) – 学习率。
eps (float = 1.0e-8) – Adagrad、LAMB 和 Adam 使用的 epsilon 值。注意,此默认值与 torch.nn.optim.Adagrad 的默认值 1e-10 不同。
momentum (float = 0.9) – LARS-SGD 使用的动量。
weight_decay (float = 0.0) –
LARS-SGD、LAMB、ADAM 和按行 Adagrad 使用的权重衰减。
EXACT_ADAGRAD、SGD、EXACT_SGD 不支持权重衰减
LAMB、ADAM、PARTIAL_ROWWISE_ADAM、PARTIAL_ROWWISE_LAMB、LARS_SGD 支持解耦权重衰减
EXACT_ROWWISE_ADAGRAD 支持 L2 和解耦权重衰减(通过 weight_decay_mode)
weight_decay_mode (WeightDecayMode = WeightDecayMode.NONE) – 权重衰减模式。选项为 WeightDecayMode.NONE、WeightDecayMode.L2 和 WeightDecayMode.DECOUPLE。
eta (float = 0.001) – LARS-SGD 使用的 eta 值。
beta1 (float = 0.9) – LAMB 和 ADAM 使用的 beta1 值。
beta2 (float = 0.999) – LAMB 和 ADAM 使用的 beta2 值。
ensemble_mode (Optional[EnsembleModeDefinition] = None) – 由集成按行 Adagrad 使用。
emainplace_mode (Optional[EmainplaceModeDefinition] = None) – 由 EMA 就地按行 Adagrad 使用。
counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None) – 由按行 Adagrad 使用。
cowclip_regularization (Optional[CowClipDefinition] = None) – 由按行 Adagrad 使用。
pooling_mode (PoolingMode = PoolingMode.SUM) –
池化模式。可用的 PoolingMode 选项包括
SUM = 求和池化
MEAN = 平均值池化
NONE = 不进行池化(序列嵌入)
device (Optional[Union[str, int, torch.device]] = None) – 当前放置张量的设备。
bounds_check_mode (BoundsCheckMode = BoundsCheckMode.WARNING) –
输入检查模式。可用的 BoundsCheckMode 选项包括
NONE = 跳过边界检查
FATAL = 遇到无效索引/偏移量时抛出错误
WARNING = 遇到无效索引/偏移量时打印警告消息并修复(将无效索引设置为零,并将无效偏移量调整到边界内)
IGNORE = 静默修复无效索引/偏移量(将无效索引设置为零,并将无效偏移量调整到边界内)
uvm_non_rowwise_momentum (bool = False) – 如果为 True,则将非按行动量放置在统一虚拟内存中。
use_experimental_tbe (bool = False) – 如果为 True,则使用优化的 TBE 实现(TBE v2)。请注意,这仅支持 NVIDIA GPU。
prefetch_pipeline (bool = False) – 如果为 True,则在使用 EmbeddingLocation.MANAGED_CACHING 时启用缓存预取流水线。目前仅支持 LRU 缓存策略。如果使用单独的流进行预取,则必须设置预取函数的可选参数 forward_stream。
stats_reporter_config (Optional[TBEStatsReporterConfig] = None) – TBE 统计报告器的配置。
table_names (Optional[List[str]] = None) – 此 TBE 中的嵌入表名称列表。
optimizer_state_dtypes (Optional[Dict[str, SparseType]] = None) – 优化器状态数据类型字典。键是优化器状态名称,值是其对应的类型
multipass_prefetch_config (Optional[MultiPassPrefetchConfig] = None) – 用于多遍缓存预取的配置(当使用 EmbeddingLocation.MANAGED_CACHING 时)
global_weight_decay (Optional[GlobalWeightDecayDefinition] = None) – 用于全局权重衰减的配置
uvm_host_mapped (bool = False) – 如果为 True,则使用 malloc + cudaHostRegister 分配每个 UVM 张量。否则使用 cudaMallocManaged
None) (extra_optimizer_config Optional[UserEnabledConfigDefinition] =) –
一个额外的配置,用于为优化器启用某些模式。这些模式默认不启用。- 在 Adam 中使用 use_rowwise_bias_correction 启用逐行偏差校正
计算
embedding_table_index_type (torch.dtype = torch.int64) – 嵌入表索引张量的数据类型。选项包括 torch.int32 和 torch.int64
embedding_table_offset_type (torch.dtype = torch.int64) – 嵌入表偏移张量的数据类型。选项包括 torch.int32 和 torch.int64
embedding_shard_info (Optional[List[Tuple[int, int, int, int]]] = None) – 关于分片位置和预分片表大小的信息。如果未设置,则表不分片。(preshard_table_height, preshard_table_dim, height_offset, dim_offset)
- forward(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor | None = None, feature_requires_grad: Tensor | None = None, batch_size_per_feature_per_rank: List[List[int]] | None = None, total_unique_indices: int | None = None) Tensor [source]¶
正向传播函数,它执行以下操作:
执行输入边界检查
生成必要的变长批量嵌入 (VBE) 元数据(如果使用 VBE)
将数据从 UVM 预取到缓存(如果使用 EmbeddingLocation.MANAGED_CACHING 且用户未明确预取数据)
通过调用相应的 Autograd 函数(根据所选优化器)执行嵌入表查找
- 参数:
indices (Tensor) – 一个 1D 张量,包含要从所有嵌入表中查找的索引
offsets (Tensor) – 一个 1D 张量,包含索引的偏移量。形状为 (B * T + 1),其中 B = 批量大小,T = 特征数量。offsets[t * B + b + 1] - offsets[t * B + b] 是特征 t 中 bag b 的长度
per_sample_weights (Optional[Tensor]) – 一个可选的 1D float 张量,包含每个样本的权重。如果为 None,将执行无权重嵌入查找。否则,将使用加权查找。此张量的长度必须与 indices 张量的长度相同。per_sample_weights[i] 的值将用于乘以查找的行 indices[i] 中的每个元素,其中 0 <= i < len(per_sample_weights)。
feature_requires_grad (Optional[Tensor]) – 一个可选的 1D 张量,用于指示 per_sample_weights 是否需要梯度。张量的长度必须等于特征数量
batch_size_per_feature_per_rank (Optional[List[List[int]]]) – 一个可选的 2D 张量,包含每个 rank 和每个特征的批量大小。如果为 None,TBE 假定每个特征具有相同的批量大小,并从 offsets 形状计算批量大小。否则,TBE 假定不同特征可以具有不同的批量大小,并使用变长批量嵌入查找模式 (VBE)。形状为(特征数量,rank 数量)。batch_size_per_feature_per_rank[f][r] 表示特征 f 和 rank r 的批量大小
total_unique_indices (Optional[int]) – 一个可选的整数,表示唯一索引的总数。当使用 OptimType.NONE 时,必须设置此值。这是因为 TBE 在反向传播中分配权重梯度张量需要此信息。
- 返回:
一个包含查找数据的 2D 张量。形状为 (B, total_D),其中 B = 批量大小,total_D = 表中所有嵌入维度的总和
示例
>>> import torch >>> >>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( >>> EmbeddingLocation, >>> ) >>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( >>> SplitTableBatchedEmbeddingBagsCodegen, >>> ComputeDevice, >>> ) >>> >>> # Two tables >>> embedding_specs = [ >>> (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA), >>> (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA) >>> ] >>> >>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs) >>> tbe.init_embedding_weights_uniform(-1, 1) >>> >>> print(tbe.split_embedding_weights()) [tensor([[-0.9426, 0.7046, 0.4214, -0.0419, 0.1331, -0.7856, -0.8124, -0.2021], [-0.5771, 0.5911, -0.7792, -0.1068, -0.6203, 0.4813, -0.1677, 0.4790], [-0.5587, -0.0941, 0.5754, 0.3475, -0.8952, -0.1964, 0.0810, -0.4174]], device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775, 0.3273], [-0.5399, -0.0229, -0.1455, -0.8770], [-0.9520, 0.4593, -0.7169, 0.6307], [-0.1765, 0.8757, 0.8614, 0.2051], [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]
>>> # Batch size = 3 >>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0], >>> device="cuda", >>> dtype=torch.long) >>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13], >>> device="cuda", >>> dtype=torch.long) >>> >>> output = tbe(indices, offsets) >>> >>> # Batch size = 3, total embedding dimension = 12 >>> print(output.shape) torch.Size([3, 12])
>>> print(output) tensor([[-1.5197, 1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801, 0.2769, -0.7164, 0.8528, 0.7159, -0.6719], [-2.0784, 1.2016, 0.2176, 0.1988, -1.3825, -0.5008, -0.8991, -0.1405, -1.2637, -0.9427, -1.8902, 0.3754], [-1.5013, 0.6105, 0.9968, 0.3057, -0.7621, -0.9821, -0.7314, -0.6195, -0.2513, -0.4039, -0.3775, 0.3273]], device='cuda:0', grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)
- split_optimizer_states() List[List[Tensor]] [source]¶
返回一个优化器状态列表(视图),按表分割
- 返回:
状态列表的列表。形状 =(表的数量,状态的数量)。
以下显示了每个优化器的状态列表(按返回顺序)
ADAM:momentum1,momentum2
EXACT_ADAGRAD:momentum1
EXACT_ROWWISE_ADAGRAD:momentum1(逐行),prev_iter(逐行;仅当使用 WeightDecayMode = COUNTER 或 COWCLIP 或 global_weight_decay 不为 None 时),row_counter(逐行;仅当使用 WeightDecayMode = COUNTER 或 COWCLIP 时)
EXACT_SGD:无状态
LAMB:momentum1,momentum2
LARS_SGD:momentum1
PARTIAL_ROWWISE_ADAM:momentum1,momentum2(逐行)
PARTIAL_ROWWISE_LAMB:momentum1,momentum2(逐行)
ENSEMBLE_ROWWISE_ADAGRAD:momentum1(逐行),momentum2
NONE:无状态(抛出错误)