• 文档 >
  • 合并嵌入运算符
快捷方式

合并嵌入运算符

稳定 API

torch.ops.fbgemm.merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) Tensor

将来自不同设备(在同一主机上)的嵌入输出连接到目标设备上。

参数::
  • pooled_embeddings (List[Tensor]) – 来自同一主机上不同设备的嵌入输出列表。每个输出都具有 2 个维度。

  • uncat_dim_size (int) – 未连接的维度的尺寸,即如果 cat_dim=0,则 uncat_dim_size 是维度 1 的尺寸,反之亦然。

  • target_device (torch.device) – 聚合所有嵌入输出的目标设备。

  • cat_dim (int = 1) – 张量连接的维度

返回::

目标设备上的连接的嵌入输出(2D)

torch.ops.fbgemm.permute_pooled_embs(pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list) Tensor

沿着特征维度排列嵌入输出。

嵌入输出张量 pooled_embs 包含批次中所有特征的嵌入输出。它以 2D 格式表示,其中行是批次尺寸维度,列是特征 * 嵌入维度。沿着特征维度排列本质上是沿着第二个维度(维度 1)排列。

参数::
  • pooled_embs (Tensor) – 要排列的嵌入输出。形状为 (B_local, total_global_D),其中 B_local = 本地批次尺寸,total_global_D 是所有特征(全局)的总嵌入维度

  • offset_dim_list (Tensor) – 所有特征的嵌入维度的完整累积和。形状为 T + 1,其中 T 是特征总数

  • permute_list (Tensor) – 描述每个特征如何排列的张量。 permute_list[i] 指示特征 permute_list[i] 被排列到位置 i

  • inv_offset_dim_list (Tensor) – 反向嵌入维度的完整累积和,这些是排列的嵌入维度。 inv_offset_dim_list[i] 表示特征 permute_list[i] 的起始嵌入位置

  • inv_permute_list (Tensor) – 反向排列列表,包含每个特征的排列位置。 inv_permute_list[i] 表示特征 i 的排列位置

返回::

排列后的嵌入输出 (Tensor)。与 pooled_embs 形状相同

示例

>>> import torch
>>> from itertools import accumulate
>>>
>>> # Suppose batch size = 3 and there are 3 features
>>> batch_size = 3
>>>
>>> # Embedding dimensions for each feature
>>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
>>>
>>> # Permute list, i.e., move feature 2 to position 0, move feature 0
>>> # to position 1, so on
>>> permute = torch.tensor([2, 0, 1], dtype=torch.int64, device="cuda")
>>>
>>> # Compute embedding dim offsets
>>> offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims)), dtype=torch.int64, device="cuda")
>>> print(offset_dim_list)
>>>
tensor([ 0,  4,  8, 16], device='cuda:0')
>>>
>>> # Compute inverse embedding dims
>>> inv_embs_dims = [embs_dims[p] for p in permute]
>>> # Compute complete cumulative sum of inverse embedding dims
>>> inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device="cuda")
>>> print(inv_offset_dim_list)
>>>
tensor([ 0,  8, 12, 16], device='cuda:0')
>>>
>>> # Compute inverse permutes
>>> inv_permute = [0] * len(permute)
>>> for i, p in enumerate(permute):
>>>     inv_permute[p] = i
>>> inv_permute_list = torch.tensor([inv_permute], dtype=torch.int64, device="cuda")
>>> print(inv_permute_list)
>>>
tensor([[1, 2, 0]], device='cuda:0')
>>>
>>> # Generate an example input
>>> pooled_embs = torch.arange(embs_dims.sum().item() * batch_size, dtype=torch.float32, device="cuda").reshape(batch_size, -1)
>>> print(pooled_embs)
>>>
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15.],
        [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
         30., 31.],
        [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
         46., 47.]], device='cuda:0')
>>>
>>> torch.ops.fbgemm.permute_pooled_embs_auto_grad(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list)
>>>
tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
          6.,  7.],
        [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
         22., 23.],
        [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
         38., 39.]], device='cuda:0')

其他 API

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源