快捷方式

模块

标准 TorchRec 模块表示嵌入表集合。

  • EmbeddingBagCollectiontorch.nn.EmbeddingBag 的集合。

  • EmbeddingCollectiontorch.nn.Embedding 的集合。

这些模块是通过标准化的配置类构建的。

  • EmbeddingBagConfig 用于 EmbeddingBagCollection

  • EmbeddingConfig 用于 EmbeddingCollection

class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False, pooling: ~torchrec.modules.embedding_configs.PoolingType = PoolingType.SUM)

基类:BaseEmbeddingConfig

EmbeddingBagConfig 是一个数据类,表示单个嵌入表,其中输出旨在进行池化。

参数:

pooling (PoolingType) – 池化类型。

class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)

基类:BaseEmbeddingConfig

EmbeddingConfig 是一个数据类,表示单个嵌入表。

class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)

嵌入配置的基类。

参数:
  • num_embeddings (int) – 嵌入的数量。

  • embedding_dim (int) – 嵌入维度。

  • name (str) – 嵌入表的名称。

  • data_type (DataType) – 嵌入表的数据类型。

  • feature_names (List[str]) – 特征名称列表。

  • weight_init_max (Optional[float]) – 权重初始化的最大值。

  • weight_init_min (Optional[float]) – 权重初始化的最小值。

  • num_embeddings_post_pruning (Optional[int]) – 推理后嵌入的数量。如果为 None,则不应用剪枝。

  • init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]) – 嵌入权重的初始化函数。

  • need_pos (bool) – 表是否按位置加权。

class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)

EmbeddingBagCollection 表示池化嵌入(EmbeddingBags)的集合。

注意

EmbeddingBagCollection 是一个未分片的模块,未针对性能进行优化。对于对性能敏感的场景,请考虑使用分片版本 ShardedEmbeddingBagCollection。

它处理 KeyedJaggedTensor 形式的稀疏数据,其值为 [F X B X L],其中

  • F:特征(键)

  • B:批次大小

  • L:稀疏特征的长度(交错)

并输出 KeyedTensor,其值为 [B * (F * D)],其中

  • F:特征(键)

  • D:每个特征(键)的嵌入维度

  • B:批次大小

参数:
  • tables (List[EmbeddingBagConfig]) – 嵌入表的列表。

  • is_weighted (bool) – 输入 KeyedJaggedTensor 是否加权。

  • 设备 (可选[torch.device]) – 默认计算设备。

示例

table_0 = EmbeddingBagConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
table_1 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)

ebc = EmbeddingBagCollection(tables=[table_0, table_1])

#        0       1        2  <-- batch
# "f1"   [0,1] None    [2]
# "f2"   [3]    [4]    [5,6,7]
#  ^
# feature

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

pooled_embeddings = ebc(features)
print(pooled_embeddings.values())
tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
    [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
    [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
    grad_fn=<CatBackward0>)
print(pooled_embeddings.keys())
['f1', 'f2']
print(pooled_embeddings.offset_per_key())
tensor([0, 3, 7])
属性 device: device

返回:torch.device:计算设备。

embedding_bag_configs() List[EmbeddingBagConfig]
返回:

嵌入包配置。

返回类型:

List[EmbeddingBagConfig]

forward(features: KeyedJaggedTensor) KeyedTensor

运行 EmbeddingBagCollection 前向传递。此方法接收一个 KeyedJaggedTensor 并返回一个 KeyedTensor,它是每个特征嵌入池化的结果。

参数:

features (KeyedJaggedTensor) – 输入 KJT

返回:

KeyedTensor

is_weighted() bool
返回:

EmbeddingBagCollection 是否加权。

返回类型:

bool

reset_parameters() None

重置 EmbeddingBagCollection 的参数。如果存在,参数值将根据每个 EmbeddingBagConfig 的 init_fn 进行初始化。

torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)

EmbeddingCollection 表示非池化嵌入的集合。

注意

EmbeddingCollection 是一个非分片模块,并且没有针对性能进行优化。对于对性能敏感的场景,请考虑使用分片版本 ShardedEmbeddingCollection。

它处理形式为 KeyedJaggedTensor 的稀疏数据,形式为 [F X B X L],其中

  • F:特征(键)

  • B:批次大小

  • L:稀疏特征的长度(可变)

并输出 Dict[特征 (键),JaggedTensor]。每个 JaggedTensor 包含形式为 (B * L) X D 的值,其中

  • B:批次大小

  • L:稀疏特征的长度(交错)

  • D:每个特征(键)的嵌入维度,长度为 L

参数:
  • tables (List[EmbeddingConfig]) – 嵌入表的列表。

  • 设备 (可选[torch.device]) – 默认计算设备。

  • need_indices (bool) – 如果我们需要将索引传递到最终的查找字典。

示例

e1_config = EmbeddingConfig(
    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
)
e2_config = EmbeddingConfig(
    name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"]
)

ec = EmbeddingCollection(tables=[e1_config, e2_config])

#     0       1        2  <-- batch
# 0   [0,1] None    [2]
# 1   [3]    [4]    [5,6,7]
# ^
# feature

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
feature_embeddings = ec(features)
print(feature_embeddings['f2'].values())
tensor([[-0.2050,  0.5478,  0.6054],
[ 0.7352,  0.3210, -3.0399],
[ 0.1279, -0.1756, -0.4130],
[ 0.7519, -0.4341, -0.0499],
[ 0.9329, -1.0697, -0.8095]], grad_fn=<EmbeddingBackward>)
属性 device: device

返回:torch.device:计算设备。

embedding_configs() List[EmbeddingConfig]
返回:

嵌入配置。

返回类型:

List[EmbeddingConfig]

embedding_dim() int
返回:

嵌入维度。

返回类型:

int

embedding_names_by_table() List[List[str]]
返回:

按表排列的嵌入名称。

返回类型:

List[List[str]]

forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]

运行 EmbeddingBagCollection 前向传递。此方法接收一个 KeyedJaggedTensor 并返回一个 Dict[str, JaggedTensor],它是每个特征的各个嵌入的结果。

参数:

features (KeyedJaggedTensor) – 形式为 [F X B X L] 的 KJT。

返回:

Dict[str, JaggedTensor]

need_indices() bool
返回:

EmbeddingCollection 是否需要索引。

返回类型:

bool

reset_parameters() None

重置 EmbeddingCollection 的参数。如果存在,参数值将根据每个 EmbeddingConfig 的 init_fn 进行初始化。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源