注意
点击 此处 下载完整的示例代码
TorchRec 简介¶
TorchRec 是一个 PyTorch 库,专门用于使用嵌入构建可扩展且高效的推荐系统。本教程将引导您完成安装过程,介绍嵌入的概念,并强调它们在推荐系统中的重要性。它提供了使用 PyTorch 和 TorchRec 实现嵌入的实际演示,重点介绍如何通过分布式训练和高级优化来处理大型嵌入表。
嵌入的基础知识及其在推荐系统中的作用
如何在 PyTorch 环境中设置 TorchRec 以管理和实现嵌入
探索将大型嵌入表分布到多个 GPU 上的高级技术
PyTorch v2.5 或更高版本,以及 CUDA 11.8 或更高版本
Python 3.9 或更高版本
安装依赖项¶
在 Google Colab 或其他环境中运行本教程之前,请安装以下依赖项
!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121
注意
如果您在 Google Colab 中运行此教程,请确保切换到 GPU 运行时类型。有关更多信息,请参阅 启用 CUDA
嵌入¶
在构建推荐系统时,类别特征通常具有巨大的基数,例如帖子、用户、广告等等。
为了表示这些实体并建模这些关系,使用**嵌入**。在机器学习中,**嵌入是高维空间中实数的向量,用于表示复杂数据(如单词、图像或用户)中的含义**。
推荐系统中的嵌入¶
现在您可能想知道,这些嵌入最初是如何生成的?嗯,嵌入表示为**嵌入表**中的单独行,也称为嵌入权重。这样做的原因是,嵌入或嵌入表权重就像模型的所有其他权重一样,通过梯度下降进行训练!
嵌入表只是用于存储嵌入的大矩阵,具有两个维度(B,N),其中
B 是表存储的嵌入数量
N 是每个嵌入的维度数(N 维嵌入)。
嵌入表的输入表示嵌入查找,以检索特定索引或行的嵌入。在许多大型系统中使用的推荐系统中,唯一 ID 不仅用于特定用户,还用于帖子和广告等实体,以用作相应嵌入表的查找索引!
推荐系统中通过以下过程训练嵌入
**输入/查找索引作为唯一 ID 馈送到模型中**。ID 会被哈希到嵌入表的总大小,以防止 ID > 行数时出现问题
然后检索嵌入并**进行池化,例如取嵌入的总和或平均值**。这是必需的,因为每个示例的嵌入数量可能可变,而模型期望一致的形状。
**嵌入与模型的其余部分一起用于生成预测**,例如广告的点击率 (CTR)。
使用预测和示例的标签计算损失,并**通过梯度下降和反向传播更新模型的所有权重,包括与示例关联的嵌入权重**。
这些嵌入对于表示类别特征(例如用户、帖子和广告)以捕获关系并做出良好的推荐至关重要。 深度学习推荐模型 (DLRM) 论文更详细地讨论了在推荐系统中使用嵌入表的技术细节。
本教程介绍了嵌入的概念,展示了 TorchRec 特定的模块和数据类型,并描述了 TorchRec 如何进行分布式训练。
import torch
PyTorch 中的嵌入¶
在 PyTorch 中,我们有以下类型的嵌入
torch.nn.Embedding
:一个嵌入表,其中前向传递按原样返回嵌入本身。torch.nn.EmbeddingBag
:嵌入表,其中前向传递返回随后被池化的嵌入,例如总和或平均值,也称为**池化嵌入**。
在本节中,我们将简要介绍如何通过将索引传递到表中来执行嵌入查找。
num_embeddings, embedding_dim = 10, 4
# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)
# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
num_embeddings, embedding_dim, _weight=weights
)
# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)
# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)
embeddings = embedding_collection(ids)
# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)
# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)
# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
Weights: tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]])
Embedding Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Embedding Bag Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Input row IDS: tensor([[1, 3]])
Embedding Collection Results:
tensor([[[0.3904, 0.6009, 0.2566, 0.7936],
[0.8694, 0.5677, 0.7411, 0.4294]]], grad_fn=<EmbeddingBackward0>)
Shape: torch.Size([1, 2, 4])
Embedding Bag Collection Results:
tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<EmbeddingBagBackward0>)
Shape: torch.Size([1, 4])
Mean: tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<MeanBackward1>)
恭喜!现在您已经对如何使用嵌入表有了基本的了解——这是现代推荐系统基础之一!这些表表示实体及其关系。例如,给定用户与其喜欢的页面和帖子之间的关系。
TorchRec 功能概述¶
在上一节中,我们学习了如何使用嵌入表,这是现代推荐系统基础之一!这些表表示实体和关系,例如用户、页面、帖子等。鉴于这些实体总是在增加,因此通常会应用**哈希**函数以确保 ID 在特定嵌入表的范围内。但是,为了表示大量实体并减少哈希冲突,这些表可能会变得非常庞大(例如,考虑广告的数量)。事实上,这些表可能会变得非常庞大,以至于即使使用 80G 内存也无法容纳在一个 GPU 上。
为了使用庞大的嵌入表训练模型,需要将这些表跨 GPU 分片,这就会在并行性和优化方面带来一整套新的问题和机遇。幸运的是,我们拥有 TorchRec 库,它已经遇到、整合并解决了其中许多问题。TorchRec 充当**提供大规模分布式嵌入原语的库**。
接下来,我们将探索 TorchRec 库的主要功能。我们将从torch.nn.Embedding
开始,并将其扩展到自定义 TorchRec 模块,探索生成嵌入分片计划的分布式训练环境,查看 TorchRec 的固有优化,并将模型扩展到在 C++ 中进行推理。以下是本节内容的简要概述
TorchRec 模块和数据类型
分布式训练、分片和优化
推理
让我们从导入 TorchRec 开始
import torchrec
本节介绍 TorchRec 模块和数据类型,包括EmbeddingCollection
和EmbeddingBagCollection
、JaggedTensor
、KeyedJaggedTensor
、KeyedTensor
等实体。
从EmbeddingBag
到EmbeddingBagCollection
¶
我们已经探讨了torch.nn.Embedding
和torch.nn.EmbeddingBag
。TorchRec 通过创建嵌入集合来扩展这些模块,换句话说,可以使用EmbeddingCollection
和EmbeddingBagCollection
拥有多个嵌入表的模块。我们将使用EmbeddingBagCollection
来表示一组嵌入包。
在下面的示例代码中,我们创建了一个EmbeddingBagCollection
(EBC),其中包含两个嵌入包,一个表示**产品**,一个表示**用户**。每个表product_table
和user_table
都由大小为 4096 的 64 维嵌入表示。
ebc = torchrec.EmbeddingBagCollection(
device="cpu",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["product"],
pooling=torchrec.PoolingType.SUM,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["user"],
pooling=torchrec.PoolingType.SUM,
)
]
)
print(ebc.embedding_bags)
ModuleDict(
(product_table): EmbeddingBag(4096, 64, mode='sum')
(user_table): EmbeddingBag(4096, 64, mode='sum')
)
让我们检查EmbeddingBagCollection
的前向方法以及模块的输入和输出
import inspect
# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
Args:
features (KeyedJaggedTensor): Input KJT
Returns:
KeyedTensor
"""
flat_feature_names: List[str] = []
for names in self._feature_names:
flat_feature_names.extend(names)
inverse_indices = reorder_inverse_indices(
inverse_indices=features.inverse_indices_or_none(),
feature_names=flat_feature_names,
)
pooled_embeddings: List[torch.Tensor] = []
feature_dict = features.to_dict()
for i, embedding_bag in enumerate(self.embedding_bags.values()):
for feature_name in self._feature_names[i]:
f = feature_dict[feature_name]
res = embedding_bag(
input=f.values(),
offsets=f.offsets(),
per_sample_weights=f.weights() if self._is_weighted else None,
).float()
pooled_embeddings.append(res)
return KeyedTensor(
keys=self._embedding_names,
values=process_pooled_embeddings(
pooled_embeddings=pooled_embeddings,
inverse_indices=inverse_indices,
),
length_per_key=self._lengths_per_embedding,
)
TorchRec 输入/输出数据类型¶
TorchRec 具有用于其模块输入和输出的不同数据类型:JaggedTensor
、KeyedJaggedTensor
和KeyedTensor
。现在您可能会问,为什么要创建新的数据类型来表示稀疏特征?为了回答这个问题,我们必须了解稀疏特征如何在代码中表示。
稀疏特征也称为id_list_feature
和id_score_list_feature
,是将用作嵌入表索引的**ID**,以检索该 ID 的嵌入。举一个非常简单的例子,想象一个单一的稀疏特征是用户交互过的广告。输入本身将是一组用户交互过的广告 ID,检索到的嵌入将是这些广告的语义表示。在代码中表示这些特征的棘手之处在于,在每个输入示例中,**ID 的数量是可变的**。有一天,用户可能只与一个广告互动,而第二天他们可能与三个广告互动。
一个简单的表示如下所示,其中我们有一个lengths
张量表示批次中每个示例中有多少个索引,以及一个包含索引本身的values
张量。
# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])
接下来,让我们看看偏移量以及每个批次中包含的内容
# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
"Second Batch: ",
id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)
from torchrec import JaggedTensor
# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())
# Convert to list of values
print("List of Values: ", jt.to_dense())
# ``__str__`` representation
print(jt)
from torchrec import KeyedJaggedTensor
# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())
# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())
# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())
# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())
# ``KeyedJaggedTensor`` string representation
print(kjt)
# Q2: What are the offsets for the ``KeyedJaggedTensor``?
# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result
# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())
# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)
# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
Offsets: tensor([1, 3])
First Batch: tensor([5])
Second Batch: tensor([7, 1])
Offsets: tensor([0, 1, 3])
List of Values: [tensor([5]), tensor([7, 1])]
JaggedTensor({
[[5], [7, 1]]
})
Keys: ['product', 'user']
Lengths: tensor([3, 1, 2, 2])
Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])
to_dict: {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7efa0af27e80>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7efa0af25720>}
KeyedJaggedTensor({
"product": [[1, 2, 1], [5]],
"user": [[2, 3], [4, 1]]
})
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
恭喜!您现在了解了 TorchRec 模块和数据类型。为自己能够坚持到这一步而感到自豪。接下来,我们将学习分布式训练和分片。
分布式训练和分片¶
现在我们已经掌握了 TorchRec 模块和数据类型,是时候将其提升到一个新的水平了。
请记住,TorchRec 的主要目的是提供分布式嵌入的原语。到目前为止,我们只在单个设备上使用嵌入表。考虑到嵌入表的大小,这已经成为可能,但在生产环境中,情况通常并非如此。嵌入表通常会变得非常庞大,一个表无法容纳在一个 GPU 上,从而需要多个设备和分布式环境。
在本节中,我们将探索如何设置分布式环境,以及实际生产训练是如何完成的,并探索使用 TorchRec 对嵌入表进行分片。
本节也只会使用 1 个 GPU,尽管它将以分布式方式进行处理。这只是训练的限制,因为训练每个 GPU 都有一个进程。推理不会遇到此要求
在下面的示例代码中,我们设置了 PyTorch 分布式环境。
警告
如果您在 Google Colab 中运行此代码,则只能调用此单元格一次,再次调用它会导致错误,因为您只能初始化一次进程组。
import os
import torch.distributed as dist
# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"
# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")
print(f"Distributed environment initialized: {dist}")
Distributed environment initialized: <module 'torch.distributed' from '/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py'>
分布式嵌入¶
我们已经使用过主要的 TorchRec 模块:EmbeddingBagCollection
。我们已经检查了它的工作原理以及 TorchRec 中数据是如何表示的。但是,我们还没有探索 TorchRec 的主要部分之一,即**分布式嵌入**。
GPU 是当今迄今为止最流行的机器学习工作负载选择,因为它们能够执行比 CPU 多几个数量级的浮点运算/秒 (FLOPs)。但是,GPU 受到稀缺的快速内存(HBM,类似于 CPU 的 RAM)的限制,通常约为 ~10 GB。
一个推荐系统 (RecSys) 模型可能包含嵌入表,其大小远远超过单个 GPU 的内存限制,因此需要将嵌入表分布到多个 GPU 上,这也被称为**模型并行** (model parallel)。另一方面,**数据并行** (data parallel) 是指在每个 GPU 上复制整个模型,每个 GPU 接收不同的数据批次进行训练,并在反向传播过程中同步梯度。
模型中**计算量较小但内存占用较大的部分(嵌入)使用模型并行进行分布**,而**计算量较大但内存占用较小的部分(密集层、MLP 等)使用数据并行进行分布**。
分片¶
为了分布嵌入表,我们将嵌入表分割成多个部分,并将这些部分放置到不同的设备上,这也被称为“分片”(sharding)。
分片嵌入表的方法有很多。最常见的方法有:
表级分片 (Table-Wise):整个表放置在一个设备上
列级分片 (Column-Wise):将嵌入表的列进行分片
行级分片 (Row-Wise):将嵌入表的行进行分片
分片模块¶
虽然所有这些看起来都需要处理和实现很多内容,但幸运的是,**TorchRec 提供了所有用于轻松进行分布式训练和推理的原语**!实际上,TorchRec 模块有两个对应的类,用于在分布式环境中使用任何 TorchRec 模块
**模块分片器 (The module sharder)**:此类公开了一个
shard
API,用于处理 TorchRec 模块的分片,生成一个分片模块。* 对于EmbeddingBagCollection
,分片器是 EmbeddingBagCollectionSharder**分片模块 (Sharded module)**:此类是 TorchRec 模块的分片变体。它与常规 TorchRec 模块具有相同的输入/输出,但经过优化,可以在分布式环境中工作。* 对于
EmbeddingBagCollection
,分片变体是 ShardedEmbeddingBagCollection
每个 TorchRec 模块都有一个未分片和分片版本。
未分片版本用于原型设计和实验。
分片版本用于分布式环境中的分布式训练和推理。
TorchRec 模块的分片版本,例如 EmbeddingBagCollection
,将处理模型并行所需的一切,例如 GPU 之间用于将嵌入分布到正确 GPU 的通信。
回顾我们的 EmbeddingBagCollection
模块
ebc
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv
# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()
# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"
print(f"Process Group: {pg}")
Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7efa0d6c9970>
规划器¶
在我们展示分片如何工作之前,我们必须了解**规划器** (planner),它可以帮助我们确定最佳的分片配置。
给定一定数量的嵌入表和一定数量的进程 (rank),可能存在许多不同的分片配置。例如,给定 2 个嵌入表和 2 个 GPU,您可以:
将 1 个表放置在每个 GPU 上
将两个表都放置在一个 GPU 上,另一个 GPU 上不放置任何表
将某些行和列放置在每个 GPU 上
鉴于所有这些可能性,我们通常希望获得一个对性能最佳的分片配置。
这就是规划器发挥作用的地方。规划器能够根据嵌入表的数量和 GPU 的数量确定最佳配置。事实证明,手动执行此操作非常困难,工程师必须考虑大量因素才能确保最佳的分片计划。幸运的是,当使用规划器时,TorchRec 提供了一个自动规划器。
TorchRec 规划器
评估硬件的内存限制
根据嵌入查找时的内存获取来估算计算量
解决数据特定因素
考虑其他硬件细节,如带宽,以生成最佳的分片计划
为了考虑所有这些变量,TorchRec 规划器可以接收 嵌入表、约束、硬件信息和拓扑结构的各种数据,以帮助生成模型的最佳分片计划,这些数据通常在各个层级提供。
要了解有关分片的更多信息,请参阅我们的 分片教程。
# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size=1,
compute_device="cuda",
)
)
# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)
print(f"Sharding Plan generated: {plan}")
Sharding Plan generated: module:
param | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
product_table | table_wise | fused | [0]
user_table | table_wise | fused | [0]
param | shard offsets | shard sizes | placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0] | [4096, 64] | rank:0/cuda:0
user_table | [0, 0] | [4096, 64] | rank:0/cuda:0
规划器结果¶
如上所示,在运行规划器时,会有相当多的输出。我们可以看到许多统计数据正在计算,以及我们的表最终放置的位置。
运行规划器的结果是一个静态计划,可以重复用于分片!这使得分片可以针对生产模型保持静态,而不是每次都确定新的分片计划。下面,我们使用分片计划最终生成我们的 ShardedEmbeddingBagCollection
。
# The static plan that was generated
plan
env = ShardingEnv.from_process_group(pg)
# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
print(f"Sharded EBC Module: {sharded_ebc}")
Sharded EBC Module: ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_output_dists):
TwPooledEmbeddingDist()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
使用 LazyAwaitable
进行 GPU 训练¶
请记住,TorchRec 是一个针对分布式嵌入的高度优化的库。TorchRec 引入的一个概念是为了提高 GPU 训练的性能,那就是 LazyAwaitable。您将在各种分片 TorchRec 模块的输出中看到 LazyAwaitable
类型。LazyAwaitable
类型所做的只是尽可能延迟计算某些结果,它通过像异步类型一样工作来实现这一点。
from typing import List
from torchrec.distributed.types import LazyAwaitable
# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
def __init__(self, size: List[int]) -> None:
super().__init__()
self._size = size
def _wait_impl(self) -> torch.Tensor:
return torch.ones(self._size)
awaitable = ExampleAwaitable([3, 2])
awaitable.wait()
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)
kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))
print(kt.keys())
print(kt.values().shape)
# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7efa0dac0940>
<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
分片 TorchRec 模块的结构¶
我们现在已经成功地根据生成的计划对 EmbeddingBagCollection
进行了分片!分片模块具有来自 TorchRec 的通用 API,这些 API 隐藏了多个 GPU 之间的分布式通信/计算。实际上,这些 API 针对训练和推理中的性能进行了高度优化。**以下是 TorchRec 提供的用于分布式训练/推理的三个通用 API**
input_dist
:处理将输入从一个 GPU 分布到另一个 GPU。lookups
:使用 FBGEMM TBE 以优化的批处理方式执行实际的嵌入查找(稍后详细介绍)。output_dist
:处理将输出从一个 GPU 分布到另一个 GPU。
输入和输出的分布是通过 NCCL 集体通信 完成的,特别是 全对全通信 (All-to-Alls),其中所有 GPU 彼此之间发送和接收数据。TorchRec 与 PyTorch 分布式进行集体通信交互,并为最终用户提供简洁的抽象,从而消除了对底层细节的关注。
反向传播执行所有这些集体通信,但顺序相反,用于分布梯度。input_dist
、lookup
和 output_dist
都依赖于分片方案。由于我们以表级方式进行分片,因此这些 API 是由 TwPooledEmbeddingSharding 构造的模块。
sharded_ebc
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists
# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
[TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)]
优化嵌入查找¶
在执行嵌入表集合的查找时,一个简单的解决方案是遍历所有 nn.EmbeddingBags
并对每个表执行一次查找。这正是标准的未分片 EmbeddingBagCollection
所做的。但是,虽然此解决方案简单,但速度非常慢。
FBGEMM 是一个提供 GPU 运算符(也称为内核)的库,这些运算符经过高度优化。其中一个运算符称为**表批处理嵌入** (TBE),提供了两个主要优化:
表批处理,允许您使用一个内核调用查找多个嵌入。
优化器融合,允许模块根据规范的 PyTorch 优化器和参数进行自我更新。
ShardedEmbeddingBagCollection
使用 FBGEMM TBE 而不是传统的 nn.EmbeddingBags
来执行优化的嵌入查找。
sharded_ebc._lookups
[GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)]
DistributedModelParallel
¶
我们现在已经探索了对单个 EmbeddingBagCollection
进行分片!我们能够使用 EmbeddingBagCollectionSharder
和未分片的 EmbeddingBagCollection
生成一个 ShardedEmbeddingBagCollection
模块。此工作流程很好,但通常在实现模型并行时,DistributedModelParallel (DMP) 用作标准接口。当使用 DMP 包装您的模型(在本例中为 ebc
)时,将发生以下情况:
确定如何对模型进行分片。DMP 将收集可用的分片器,并制定一个关于如何以最佳方式对嵌入表 (例如
EmbeddingBagCollection
) 进行分片的计划。实际对模型进行分片。这包括在适当的设备上为每个嵌入表分配内存。
DMP 接收了我们刚刚试验的所有内容,例如静态分片计划、分片器列表等。但是,它也有一些不错的默认设置,可以无缝地对 TorchRec 模型进行分片。在此玩具示例中,由于我们有两个嵌入表和一个 GPU,因此 TorchRec 将两个表都放置在单个 GPU 上。
ebc
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
out = model(kjt)
out.wait()
model
WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.
DistributedModelParallel(
(_dmp_wrapped_module): ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_input_dists):
TwSparseFeaturesDist(
(_dist): KJTAllToAll()
)
(_output_dists):
TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
分片最佳实践¶
目前,我们的配置仅在一个 GPU(或进程)上进行分片,这很简单:只需将所有表都放置在一个 GPU 的内存中。但是,在实际的生产用例中,嵌入表**通常分布在数百个 GPU 上**,并使用不同的分片方法,如表级、行级和列级。确定合适的分片配置(以防止内存不足问题)并保持其平衡(不仅在内存方面,而且在计算方面)以获得最佳性能,这一点非常重要。
添加优化器¶
请记住,TorchRec 模块针对大规模分布式训练进行了高度优化。一个重要的优化与优化器有关。
TorchRec 模块提供了一个无缝的 API 来融合训练中的反向传播和优化步骤,显著提升了性能并减少了内存使用,同时还提供了将不同的优化器分配给不同的模型参数的灵活性。
优化器类¶
TorchRec 使用 CombinedOptimizer
,它包含一系列 KeyedOptimizers
。一个 CombinedOptimizer
有效地简化了处理模型中各个子组的多个优化器。一个 KeyedOptimizer
扩展了 torch.optim.Optimizer
,并通过参数字典进行初始化,公开参数。在 EmbeddingBagCollection
中的每个 TBE
模块都将拥有自己的 KeyedOptimizer
,它们组合成一个 CombinedOptimizer
。
TorchRec 中的融合优化器¶
使用 DistributedModelParallel
,**优化器被融合,这意味着优化器更新是在反向传播中完成的**。这是 TorchRec 和 FBGEMM 中的一种优化,其中优化器嵌入梯度不会被具体化,而是直接应用于参数。这带来了显著的内存节省,因为嵌入梯度通常与参数本身的大小相同。
但是,您可以选择使优化器 dense
,它不应用此优化,并允许您检查嵌入梯度或根据需要对其应用计算。在这种情况下,密集优化器将是您带有优化器的标准 PyTorch 模型训练循环。
通过 DistributedModelParallel
创建优化器后,您仍然需要为与 TorchRec 嵌入模块无关的其他参数管理一个优化器。要查找其他参数,请使用 in_backward_optimizer_filter(model.named_parameters())
。像使用普通的 Torch 优化器一样将优化器应用于这些参数,并将此优化器与 model.fused_optimizer
组合成一个 CombinedOptimizer
,您可以在训练循环中使用它来执行 zero_grad
和 step
操作。
向 EmbeddingBagCollection
添加优化器¶
我们将通过两种等效的方法来实现这一点,但根据您的偏好提供不同的选项。
通过分片器中的
fused_params
传递优化器关键字参数。通过
apply_optimizer_in_backward
,它将优化器参数转换为fused_params
以传递给EmbeddingBagCollection
或EmbeddingCollection
中的TBE
。
# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
from fbgemm_gpu.split_embedding_configs import EmbOptimType
# We initialize the sharder with
fused_params = {
"optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
"learning_rate": 0.02,
"eps": 0.002,
}
# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))
# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")
from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
import copy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it
# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}
for name, param in ebc_apply_opt.named_parameters():
print(f"{name=}")
apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)
sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))
# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))
# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())
# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")
loss.backward()
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.01
)
Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.02
)
Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
/var/lib/workspace/intermediate_source/torchrec_intro_tutorial.py:876: DeprecationWarning:
`TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
name='embedding_bags.product_table.weight'
name='embedding_bags.user_table.weight'
: EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.5
)
<class 'torchrec.optim.keyed.CombinedOptimizer'>
Non Fused Model Parameters:
dict_keys([])
First Iteration Loss: 255.66006469726562
Second Iteration Loss: 245.43795776367188
推理¶
现在我们能够训练分布式嵌入,那么如何获取训练好的模型并将其优化以用于推理呢?推理通常对**模型的性能和大小**非常敏感。仅在 Python 环境中运行训练好的模型效率极低。推理环境和训练环境之间有两个主要区别。
**量化**:推理模型通常会被量化,其中模型参数会损失精度以降低预测的延迟并减小模型大小。例如,训练模型中的 FP32(4 字节)到每个嵌入权重的 INT8(1 字节)。鉴于嵌入表的规模巨大,这也是必要的,因为我们希望在推理中使用尽可能少的设备来最大程度地减少延迟。
**C++ 环境**:推理延迟非常重要,因此为了确保足够的性能,模型通常在 C++ 环境中运行,以及在没有 Python 运行时的情况下(例如在设备上)运行。
TorchRec 提供了用于将 TorchRec 模型转换为推理就绪的原语,包括:
用于量化模型的 API,使用 FBGEMM TBE 自动引入优化
用于分布式推理的嵌入分片
将模型编译为 TorchScript(与 C++ 兼容)
在本节中,我们将介绍整个工作流程:
量化模型
分片量化模型
将分片量化模型编译为 TorchScript
ebc
class InferenceModule(torch.nn.Module):
def __init__(self, ebc: torchrec.EmbeddingBagCollection):
super().__init__()
self.ebc_ = ebc
def forward(self, kjt: KeyedJaggedTensor):
return self.ebc_(kjt)
module = InferenceModule(ebc)
for name, param in module.named_parameters():
# Here, the parameters should still be FP32, as we are using a standard EBC
# FP32 is default, regularly used for training
print(name, param.shape, param.dtype)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32
ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32
量化¶
如上所示,普通的 EBC 将嵌入表权重作为 FP32 精度(每个权重 32 位)。在这里,我们将使用 TorchRec 推理库将模型的嵌入权重量化为 INT8。
from torch import quantization as quant
from torchrec.modules.embedding_configs import QuantConfig
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
quant_dtype = torch.int8
qconfig = QuantConfig(
# dtype of the result of the embedding lookup, post activation
# torch.float generally for compatibility with rest of the model
# as rest of the model here usually isn't quantized
activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
# quantized type for embedding weights, aka parameters to actually quantize
weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
# Map of module type to qconfig
torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
# Map of module type to quantized module type
torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}
module = InferenceModule(ebc)
# Quantize the module
qebc = quant.quantize_dynamic(
module,
qconfig_spec=qconfig_spec,
mapping=mapping,
inplace=False,
)
print(f"Quantized EBC: {qebc}")
kjt = kjt.to("cpu")
qebc(kjt)
# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
# The shapes of the tables should be the same but the dtype should be int8 now
# post quantization
print(name, buffer.shape, buffer.dtype)
Quantized EBC: InferenceModule(
(ebc_): QuantizedEmbeddingBagCollection(
(_kjt_to_jt_dict): ComputeKJTToJTDict()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8
ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8
分片¶
在这里,我们对 TorchRec 量化模型进行分片。这是为了确保我们通过 FBGEMM TBE 使用高性能模块。在这里,我们使用一个设备以与训练保持一致(1 个 TBE)。
from torchrec import distributed as trec_dist
from torchrec.distributed.shard import _shard_modules
sharded_qebc = _shard_modules(
module=qebc,
device=torch.device("cpu"),
env=trec_dist.ShardingEnv.from_local(
1,
0,
),
)
print(f"Sharded Quantized EBC: {sharded_qebc}")
sharded_qebc(kjt)
WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.
Sharded Quantized EBC: InferenceModule(
(ebc_): ShardedQuantEmbeddingBagCollection(
(lookups):
InferGroupedPooledEmbeddingsLookup()
(_output_dists): ModuleList()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
(_input_dist_module): ShardedQuantEbcInputDist()
)
)
<torchrec.sparse.jagged_tensor.KeyedTensor object at 0x7efa0e5dc850>
编译¶
现在我们有了优化的 TorchRec 推理模型。下一步是确保此模型可以在 C++ 中加载,因为它目前只能在 Python 运行时中运行。
Meta 推荐的编译方法是两方面的:torch.fx 追踪(生成模型的中间表示)并将结果转换为 TorchScript,其中 TorchScript 与 C++ 兼容。
from torchrec.fx import Tracer
tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])
graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)
print("Graph Module Created!")
print(gm.code)
scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")
print(scripted_gm.code)
Graph Module Created!
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embeddingbag_flatten_feature_lengths")
torch.fx._symbolic_trace.wrap("torchrec_fx_utils__fx_marker")
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embedding_kernel__unwrap_kjt")
torch.fx._symbolic_trace.wrap("torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference")
def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):
flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt); kjt = None
_fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths); _fx_marker = None
split = flatten_feature_lengths.split([2])
getitem = split[0]; split = None
to = getitem.to(device(type='cuda', index=0), non_blocking = True); getitem = None
_fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths); flatten_feature_lengths = _fx_marker_1 = None
_unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to); to = None
getitem_1 = _unwrap_kjt[0]
getitem_2 = _unwrap_kjt[1]
getitem_3 = _unwrap_kjt[2]; _unwrap_kjt = getitem_3 = None
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1, None); _tensor_constant0 = _tensor_constant1 = bounds_check_indices = None
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_1, offsets = getitem_2, pooling_mode = 0, indice_weights = None, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1); _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_1 = getitem_2 = _tensor_constant8 = _tensor_constant9 = None
embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32); int_nbit_split_embedding_codegen_lookup_function = None
to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu')); embeddings_cat_empty_rank_handle_inference = None
keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1); to_1 = None
return keyed_tensor
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/jit/_check.py:178: UserWarning:
The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
Scripted Graph Module Created!
def forward(self,
kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:
_0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths
_1 = __torch__.torchrec.fx.utils._fx_marker
_2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt
_3 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference
flatten_feature_lengths = _0(kjt, )
_fx_marker = _1("KJT_ONE_TO_ALL_FORWARD_BEGIN", flatten_feature_lengths, )
split = (flatten_feature_lengths).split([2], )
getitem = split[0]
to = (getitem).to(torch.device("cuda", 0), True, None, )
_fx_marker_1 = _1("KJT_ONE_TO_ALL_FORWARD_END", flatten_feature_lengths, )
_unwrap_kjt = _2(to, )
getitem_1 = (_unwrap_kjt)[0]
getitem_2 = (_unwrap_kjt)[1]
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1)
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_1, getitem_2, 0, None, 0, _tensor_constant8, _tensor_constant9, 16)
_4 = [int_nbit_split_embedding_codegen_lookup_function]
embeddings_cat_empty_rank_handle_inference = _3(_4, 1, "cuda:0", 6, )
to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device("cpu"))
_5 = ["product", "user"]
_6 = [64, 64]
keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)
_7 = (keyed_tensor).__init__(_5, _6, to_1, 1, None, None, )
return keyed_tensor
结论¶
在本教程中,您已经从训练分布式 RecSys 模型到使其推理就绪进行了完整的学习。 TorchRec 仓库 有一个完整的示例,说明如何将 TorchRec TorchScript 模型加载到 C++ 中进行推理。
有关更多信息,请参阅我们的 dlrm 示例,其中包括使用本文档中描述的方法在 Criteo 1TB 数据集上进行多节点训练。Deep Learning Recommendation Model for Personalization and Recommendation Systems。
脚本的总运行时间:(0 分钟 0.826 秒)