TorchRec 概念¶
在本节中,我们将学习 TorchRec 的关键概念,这些概念旨在优化使用 PyTorch 的大规模推荐系统。我们将详细了解每个概念的工作原理,并回顾它如何与 TorchRec 的其余部分一起使用。
TorchRec 具有其模块的特定输入/输出数据类型,以有效地表示稀疏特征,包括
JaggedTensor: 围绕长度/偏移量和值张量的包装器,用于表示单个稀疏特征。
KeyedJaggedTensor: 有效地表示多个稀疏特征,可以将其视为多个
JaggedTensor
。KeyedTensor:
torch.Tensor
的包装器,允许通过键访问张量值。
为了实现高性能和效率,规范的 torch.Tensor
在表示稀疏数据时效率极低。TorchRec 引入了这些新的数据类型,因为它们提供了稀疏输入数据的有效存储和表示。正如您稍后将看到的,KeyedJaggedTensor
使分布式环境中的输入数据通信非常高效,从而带来了 TorchRec 提供的关键性能优势之一。
在端到端训练循环中,TorchRec 包含以下主要组件
Planner: 接收嵌入表的配置、环境设置,并为模型生成优化的分片计划。
Sharder: 根据分片计划对模型进行分片,采用不同的分片策略,包括数据并行、表式、行式、表式-行式、列式和表式-列式分片。
DistributedModelParallel: 结合了 sharder 和优化器,并提供了以分布式方式训练模型的入口点。
JaggedTensor¶
JaggedTensor
通过长度、值和偏移量来表示稀疏特征。它之所以被称为“jagged”,是因为它可以有效地表示可变长度序列的数据。相比之下,规范的 torch.Tensor
假设每个序列具有相同的长度,但这在真实世界数据中通常不是这种情况。JaggedTensor
有助于表示此类数据而无需填充,从而使其非常高效。
主要组件
Lengths
: 一个整数列表,表示每个实体的元素数量。Offsets
: 一个整数列表,表示扁平化值张量中每个序列的起始索引。这些提供了长度的替代方案。Values
: 一个 1D 张量,包含每个实体的实际值,连续存储。
这是一个简单的示例,演示了每个组件的外观
# User interactions:
# - User 1 interacted with 2 items
# - User 2 interacted with 3 items
# - User 3 interacted with 1 item
lengths = [2, 3, 1]
offsets = [0, 2, 5] # Starting index of each user's interactions
values = torch.Tensor([101, 102, 201, 202, 203, 301]) # Item IDs interacted with
jt = JaggedTensor(lengths=lengths, values=values)
# OR
jt = JaggedTensor(offsets=offsets, values=values)
KeyedJaggedTensor¶
KeyedJaggedTensor
通过引入键(通常是特征名称)来标记不同的特征组(例如,用户特征和项目特征),从而扩展了 JaggedTensor
的功能。这是 EmbeddingBagCollection
和 EmbeddingCollection
的 forward
中使用的数据类型,因为它们用于表示表中的多个特征。
KeyedJaggedTensor
具有隐含的批次大小,即特征数量除以 lengths
张量的长度。下面的示例的批次大小为 2。与 JaggedTensor
类似,offsets
和 lengths
的功能相同。您还可以通过从 KeyedJaggedTensor
访问键来访问特征的 lengths
、offsets
和 values
。
keys = ["user_features", "item_features"]
# Lengths of interactions:
# - User features: 2 users, with 2 and 3 interactions respectively
# - Item features: 2 items, with 1 and 2 interactions respectively
lengths = [2, 3, 1, 2]
values = torch.Tensor([11, 12, 21, 22, 23, 101, 102, 201])
# Create a KeyedJaggedTensor
kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values)
# Access the features by key
print(kjt["user_features"])
# Outputs user features
print(kjt["item_features"])
Planner¶
TorchRec planner 帮助确定模型的最佳分片配置。它评估嵌入表分片的多种可能性,并优化性能。planner 执行以下操作
评估硬件的内存约束。
根据内存获取(例如嵌入查找)估算计算需求。
解决特定于数据的因素。
考虑其他硬件细节,例如带宽,以生成最佳分片计划。
为了确保准确考虑这些因素,Planner 可以结合有关嵌入表、约束、硬件信息和拓扑的数据,以帮助生成最佳计划。
使用 TorchRec 分片模块进行分布式训练¶
有许多可用的分片策略,我们如何确定使用哪一个?每种分片方案都有相关的成本,这与模型大小和 GPU 数量相结合,决定了哪种分片策略最适合模型。
在没有分片的情况下,每个 GPU 保留嵌入表的副本 (DP),主要成本是计算,其中每个 GPU 在前向传递中查找其内存中的嵌入向量,并在后向传递中更新梯度。
使用分片时,会增加通信成本:每个 GPU 都需要向其他 GPU 请求嵌入向量查找,并通信计算出的梯度。这通常被称为 all2all
通信。在 TorchRec 中,对于给定 GPU 上的输入数据,我们确定数据每个部分的嵌入分片所在的位置,并将其发送到目标 GPU。然后,目标 GPU 将嵌入向量返回给原始 GPU。在后向传递中,梯度被发送回目标 GPU,并且分片会通过优化器进行相应的更新。
如上所述,分片需要我们通信输入数据和嵌入查找。TorchRec 在三个主要阶段处理此问题,我们将此称为分片嵌入模块前向传递,该传递用于 TorchRec 模型的训练和推理
特征 All to All/输入分布 (
input_dist
)将输入数据(以
KeyedJaggedTensor
的形式)通信到包含相关嵌入表分片的适当设备
嵌入查找
使用特征 all to all 交换后形成的新输入数据查找嵌入
嵌入 All to All/输出分布 (
output_dist
)将嵌入查找数据通信回请求它的适当设备(根据设备接收到的输入数据)
后向传递执行相同的操作,但顺序相反。
下图演示了其工作原理

图 2:表式分片表的前向传递,包括分片 TorchRec 模块的 input_dist、lookup 和 output_dist¶
DistributedModelParallel¶
以上所有内容最终汇集成 TorchRec 用于分片和集成计划的主要入口点。在高层次上,DistributedModelParallel
执行以下操作
通过设置进程组和分配设备类型来初始化环境。
如果没有提供 sharder,则使用默认的 sharder,默认 sharder 包括
EmbeddingBagCollectionSharder
。接收提供的分片计划,如果未提供,则生成一个。
创建模块的分片版本,并用它们替换原始模块,例如,将
EmbeddingCollection
转换为ShardedEmbeddingCollection
。默认情况下,使用
DistributedDataParallel
包装DistributedModelParallel
,使模块既是模型并行又是数据并行。
Optimizer¶
TorchRec 模块提供了一个无缝 API,用于在训练中融合后向传递和优化器步骤,从而显着优化性能并减少使用的内存,同时还可以在将不同的优化器分配给不同的模型参数方面提供粒度。

图 3:将嵌入后向传递与稀疏优化器融合¶
推理¶
推理环境与训练环境不同,它们对性能和模型大小非常敏感。TorchRec 推理优化的两个主要区别是
量化: 推理模型经过量化以实现更低的延迟和更小的模型大小。此优化使我们能够使用尽可能少的设备进行推理,从而最大限度地减少延迟。
C++ 环境: 为了进一步最大限度地减少延迟,模型在 C++ 环境中运行。
TorchRec 提供了以下内容,以将 TorchRec 模型转换为可用于推理的模型
用于量化模型的 API,包括使用 FBGEMM TBE 自动进行的优化
用于分布式推理的分片嵌入
将模型编译为 TorchScript(与 C++ 兼容)