池化嵌入模块¶
稳定 API¶
- class fbgemm_gpu.permute_pooled_embedding_modules.PermutePooledEmbeddings(embs_dims: List[int], permute: List[int], device: device | None = None)[source]¶
一个用于沿特征维度排列嵌入输出的模块
一个嵌入输出张量包含批次中所有特征的嵌入输出。它以 2D 格式表示,其中行是批次大小维度,列是特征 * 嵌入维度。沿特征维度排列本质上是在第二维度(维度 1)上排列。
示例
>>> import torch >>> import fbgemm_gpu >>> from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings >>> >>> # Suppose batch size = 3 and there are 3 features >>> batch_size = 3 >>> >>> # Embedding dimensions for each feature >>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda") >>> >>> # Permute list, i.e., move feature 2 to position 0, move feature 0 >>> # to position 1, so on >>> permute = [2, 0, 1] >>> >>> # Instantiate the module >>> perm = PermutePooledEmbeddings(embs_dims, permute) >>> >>> # Generate an example input >>> pooled_embs = torch.arange( >>> embs_dims.sum().item() * batch_size, >>> dtype=torch.float32, device="cuda" >>> ).reshape(batch_size, -1) >>> print(pooled_embs) >>> tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31.], [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.]], device='cuda:0') >>> >>> # Invoke >>> perm(pooled_embs) >>> tensor([[ 8., 9., 10., 11., 12., 13., 14., 15., 0., 1., 2., 3., 4., 5., 6., 7.], [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21., 22., 23.], [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37., 38., 39.]], device='cuda:0')
- 参数:
embs_dims (List[int]) – 所有特征的嵌入维度列表。长度 = 特征数量
permute (List[int]) – 描述每个特征如何排列的列表。 permute[i] 将特征 permute[i] 排列到位置 i。
device (Optional[torch.device] = None) – 在该设备上运行此模块