合并嵌入运算符¶
稳定 API¶
- torch.ops.fbgemm.merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) Tensor ¶
将来自不同设备(在同一主机上)的嵌入输出连接到目标设备上。
- 参数::
pooled_embeddings (List[Tensor]) – 来自同一主机上不同设备的嵌入输出列表。每个输出都具有 2 个维度。
uncat_dim_size (int) – 未连接的维度的尺寸,即如果 cat_dim=0,则 uncat_dim_size 是维度 1 的尺寸,反之亦然。
target_device (torch.device) – 聚合所有嵌入输出的目标设备。
cat_dim (int = 1) – 张量连接的维度
- 返回::
目标设备上的连接的嵌入输出(2D)
- torch.ops.fbgemm.permute_pooled_embs(pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list) Tensor ¶
沿着特征维度排列嵌入输出。
嵌入输出张量 pooled_embs 包含批次中所有特征的嵌入输出。它以 2D 格式表示,其中行是批次尺寸维度,列是特征 * 嵌入维度。沿着特征维度排列本质上是沿着第二个维度(维度 1)排列。
- 参数::
pooled_embs (Tensor) – 要排列的嵌入输出。形状为 (B_local, total_global_D),其中 B_local = 本地批次尺寸,total_global_D 是所有特征(全局)的总嵌入维度
offset_dim_list (Tensor) – 所有特征的嵌入维度的完整累积和。形状为 T + 1,其中 T 是特征总数
permute_list (Tensor) – 描述每个特征如何排列的张量。 permute_list[i] 指示特征 permute_list[i] 被排列到位置 i
inv_offset_dim_list (Tensor) – 反向嵌入维度的完整累积和,这些是排列的嵌入维度。 inv_offset_dim_list[i] 表示特征 permute_list[i] 的起始嵌入位置
inv_permute_list (Tensor) – 反向排列列表,包含每个特征的排列位置。 inv_permute_list[i] 表示特征 i 的排列位置
- 返回::
排列后的嵌入输出 (Tensor)。与 pooled_embs 形状相同
示例
>>> import torch >>> from itertools import accumulate >>> >>> # Suppose batch size = 3 and there are 3 features >>> batch_size = 3 >>> >>> # Embedding dimensions for each feature >>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda") >>> >>> # Permute list, i.e., move feature 2 to position 0, move feature 0 >>> # to position 1, so on >>> permute = torch.tensor([2, 0, 1], dtype=torch.int64, device="cuda") >>> >>> # Compute embedding dim offsets >>> offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims)), dtype=torch.int64, device="cuda") >>> print(offset_dim_list) >>> tensor([ 0, 4, 8, 16], device='cuda:0') >>> >>> # Compute inverse embedding dims >>> inv_embs_dims = [embs_dims[p] for p in permute] >>> # Compute complete cumulative sum of inverse embedding dims >>> inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device="cuda") >>> print(inv_offset_dim_list) >>> tensor([ 0, 8, 12, 16], device='cuda:0') >>> >>> # Compute inverse permutes >>> inv_permute = [0] * len(permute) >>> for i, p in enumerate(permute): >>> inv_permute[p] = i >>> inv_permute_list = torch.tensor([inv_permute], dtype=torch.int64, device="cuda") >>> print(inv_permute_list) >>> tensor([[1, 2, 0]], device='cuda:0') >>> >>> # Generate an example input >>> pooled_embs = torch.arange(embs_dims.sum().item() * batch_size, dtype=torch.float32, device="cuda").reshape(batch_size, -1) >>> print(pooled_embs) >>> tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31.], [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.]], device='cuda:0') >>> >>> torch.ops.fbgemm.permute_pooled_embs_auto_grad(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list) >>> tensor([[ 8., 9., 10., 11., 12., 13., 14., 15., 0., 1., 2., 3., 4., 5., 6., 7.], [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21., 22., 23.], [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37., 38., 39.]], device='cuda:0')