注意
点击此处下载完整的示例代码
TorchRec 简介¶
创建于:2024 年 10 月 2 日 | 最后更新:2024 年 10 月 10 日 | 最后验证:2024 年 10 月 2 日
TorchRec 是一个 PyTorch 库,专为使用嵌入构建可扩展且高效的推荐系统而定制。本教程将指导您完成安装过程,介绍嵌入的概念,并强调其在推荐系统中的重要性。它提供了关于使用 PyTorch 和 TorchRec 实现嵌入的实践演示,重点介绍了通过分布式训练和高级优化来处理大型嵌入表。
嵌入的基础知识及其在推荐系统中的作用
如何在 PyTorch 环境中设置 TorchRec 以管理和实现嵌入
探索跨多个 GPU 分布大型嵌入表的高级技术
PyTorch v2.5 或更高版本,搭配 CUDA 11.8 或更高版本
Python 3.9 或更高版本
安装依赖项¶
在 Google Colab 或其他环境中运行本教程之前,请安装以下依赖项
!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121
注意
如果您在 Google Colab 中运行此教程,请确保切换到 GPU 运行时类型。有关更多信息,请参阅启用 CUDA
嵌入¶
在构建推荐系统时,类别特征通常具有巨大的基数,例如帖子、用户、广告等等。
为了表示这些实体并建模这些关系,使用了嵌入。在机器学习中,嵌入是高维空间中的实数向量,用于表示复杂数据(如单词、图像或用户)的含义。
RecSys 中的嵌入¶
现在您可能想知道,这些嵌入最初是如何生成的?嗯,嵌入表示为嵌入表中的各个行,也称为嵌入权重。这样做的原因是,嵌入或嵌入表权重像模型的所有其他权重一样,通过梯度下降进行训练!
嵌入表只是一个用于存储嵌入的大型矩阵,具有两个维度 (B, N),其中
B 是表存储的嵌入数量
N 是每个嵌入的维度数(N 维嵌入)。
嵌入表的输入表示嵌入查找,用于检索特定索引或行的嵌入。在推荐系统中,例如许多大型系统中使用的那些,唯一的 ID 不仅用于特定用户,还用于帖子和广告等实体,作为各自嵌入表的查找索引!
嵌入通过以下过程在 RecSys 中进行训练
输入/查找索引作为唯一 ID 被馈送到模型中。ID 被哈希到嵌入表的总大小,以防止 ID > 行数时出现问题
然后检索嵌入并进行池化,例如取嵌入的总和或平均值。这是必需的,因为每个示例可能存在可变数量的嵌入,而模型期望一致的形状。
嵌入与模型的其余部分结合使用,以产生预测,例如广告的点击率 (CTR)。
损失是根据示例的预测和标签计算出来的,并且模型的所有权重都通过梯度下降和反向传播进行更新,包括与该示例关联的嵌入权重。
这些嵌入对于表示类别特征(如用户、帖子和广告)至关重要,以便捕获关系并做出良好的推荐。深度学习推荐模型 (DLRM) 论文更详细地介绍了在 RecSys 中使用嵌入表的技术细节。
本教程介绍了嵌入的概念,展示了 TorchRec 特定的模块和数据类型,并描述了分布式训练如何与 TorchRec 一起工作。
import torch
PyTorch 中的嵌入¶
在 PyTorch 中,我们有以下类型的嵌入
torch.nn.Embedding
:一个嵌入表,其中前向传递按原样返回嵌入本身。torch.nn.EmbeddingBag
:嵌入表,其中前向传递返回然后进行池化的嵌入,例如,总和或平均值,也称为池化嵌入。
在本节中,我们将简要介绍通过将索引传递到表中来执行嵌入查找。
num_embeddings, embedding_dim = 10, 4
# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)
# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
num_embeddings, embedding_dim, _weight=weights
)
# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)
# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)
embeddings = embedding_collection(ids)
# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)
# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)
# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
Weights: tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]])
Embedding Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Embedding Bag Collection Table: Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090],
[0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Input row IDS: tensor([[1, 3]])
Embedding Collection Results:
tensor([[[0.3904, 0.6009, 0.2566, 0.7936],
[0.8694, 0.5677, 0.7411, 0.4294]]], grad_fn=<EmbeddingBackward0>)
Shape: torch.Size([1, 2, 4])
Embedding Bag Collection Results:
tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<EmbeddingBagBackward0>)
Shape: torch.Size([1, 4])
Mean: tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<MeanBackward1>)
恭喜!现在您已经基本了解了如何使用嵌入表——现代推荐系统的基础之一!这些表表示实体及其关系。例如,给定用户与其点赞的页面和帖子之间的关系。
TorchRec 功能概述¶
在上面的部分中,我们学习了如何使用嵌入表,这是现代推荐系统的基础之一!这些表表示实体和关系,例如用户、页面、帖子等。鉴于这些实体总是在增加,通常会应用 hash 函数来确保 ID 在某个嵌入表的范围内。然而,为了表示大量的实体并减少哈希冲突,这些表可能会变得非常庞大(例如,想想广告的数量)。事实上,这些表可能会变得非常庞大,以至于即使有 80G 内存也无法容纳在 1 个 GPU 上。
为了使用大型嵌入表训练模型,需要跨 GPU 分片这些表,这随后引入了一系列新的并行化和优化问题和机会。幸运的是,我们有 TorchRec 库,它已经遇到、整合并解决了许多这些问题。TorchRec 作为一个库,为大规模分布式嵌入提供原语。
接下来,我们将探索 TorchRec 库的主要功能。我们将从 torch.nn.Embedding
开始,并将其扩展到自定义 TorchRec 模块,探索分布式训练环境,生成嵌入的分片计划,查看 TorchRec 固有的优化,并扩展模型以准备好在 C++ 中进行推理。以下是本节内容的快速概述
TorchRec 模块和数据类型
分布式训练、分片和优化
推理
让我们从导入 TorchRec 开始
import torchrec
本节介绍 TorchRec 模块和数据类型,包括 EmbeddingCollection
和 EmbeddingBagCollection
、JaggedTensor
、KeyedJaggedTensor
、KeyedTensor
等实体。
从 EmbeddingBag
到 EmbeddingBagCollection
¶
我们已经探索了 torch.nn.Embedding
和 torch.nn.EmbeddingBag
。TorchRec 通过创建嵌入集合来扩展这些模块,换句话说,具有多个嵌入表的模块,使用 EmbeddingCollection
和 EmbeddingBagCollection
。我们将使用 EmbeddingBagCollection
来表示一组 embedding bag。
在下面的示例代码中,我们创建了一个具有两个 embedding bag 的 EmbeddingBagCollection
(EBC),其中 1 个表示产品,1 个表示用户。每个表 product_table
和 user_table
都由大小为 4096 的 64 维嵌入表示。
ebc = torchrec.EmbeddingBagCollection(
device="cpu",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["product"],
pooling=torchrec.PoolingType.SUM,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["user"],
pooling=torchrec.PoolingType.SUM,
)
]
)
print(ebc.embedding_bags)
ModuleDict(
(product_table): EmbeddingBag(4096, 64, mode='sum')
(user_table): EmbeddingBag(4096, 64, mode='sum')
)
让我们检查 EmbeddingBagCollection
的前向方法以及模块的输入和输出
import inspect
# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
Args:
features (KeyedJaggedTensor): Input KJT
Returns:
KeyedTensor
"""
flat_feature_names: List[str] = []
for names in self._feature_names:
flat_feature_names.extend(names)
inverse_indices = reorder_inverse_indices(
inverse_indices=features.inverse_indices_or_none(),
feature_names=flat_feature_names,
)
pooled_embeddings: List[torch.Tensor] = []
feature_dict = features.to_dict()
for i, embedding_bag in enumerate(self.embedding_bags.values()):
for feature_name in self._feature_names[i]:
f = feature_dict[feature_name]
res = embedding_bag(
input=f.values(),
offsets=f.offsets(),
per_sample_weights=f.weights() if self._is_weighted else None,
).float()
pooled_embeddings.append(res)
return KeyedTensor(
keys=self._embedding_names,
values=process_pooled_embeddings(
pooled_embeddings=pooled_embeddings,
inverse_indices=inverse_indices,
),
length_per_key=self._lengths_per_embedding,
)
TorchRec 输入/输出数据类型¶
TorchRec 对其模块的输入和输出具有不同的数据类型:JaggedTensor
、KeyedJaggedTensor
和 KeyedTensor
。现在您可能会问,为什么要创建新的数据类型来表示稀疏特征?要回答这个问题,我们必须了解稀疏特征在代码中是如何表示的。
稀疏特征也称为 id_list_feature
和 id_score_list_feature
,并且是将被用作嵌入表索引以检索该 ID 的嵌入的 ID。举一个非常简单的例子,想象一下一个用户交互过的广告的单个稀疏特征。输入本身将是用户交互过的一组广告 ID,检索到的嵌入将是这些广告的语义表示。在代码中表示这些特征的棘手之处在于,在每个输入示例中,ID 的数量是可变的。用户某一天可能只与一个广告互动,而第二天他们可能会与三个广告互动。
下面显示了一个简单的表示,其中我们有一个 lengths
张量,表示一个批次中一个示例中有多少个索引,以及一个包含索引本身的 values
张量。
# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])
接下来,让我们看一下偏移量以及每个批次中包含的内容
# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
"Second Batch: ",
id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)
from torchrec import JaggedTensor
# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())
# Convert to list of values
print("List of Values: ", jt.to_dense())
# ``__str__`` representation
print(jt)
from torchrec import KeyedJaggedTensor
# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())
# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())
# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())
# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())
# ``KeyedJaggedTensor`` string representation
print(kjt)
# Q2: What are the offsets for the ``KeyedJaggedTensor``?
# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result
# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())
# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)
# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
Offsets: tensor([1, 3])
First Batch: tensor([5])
Second Batch: tensor([7, 1])
Offsets: tensor([0, 1, 3])
List of Values: [tensor([5]), tensor([7, 1])]
JaggedTensor({
[[5], [7, 1]]
})
Keys: ['product', 'user']
Lengths: tensor([3, 1, 2, 2])
Values: tensor([1, 2, 1, 5, 2, 3, 4, 1])
to_dict: {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f377446ee00>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f36958db5e0>}
KeyedJaggedTensor({
"product": [[1, 2, 1], [5]],
"user": [[2, 3], [4, 1]]
})
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
恭喜!您现在了解了 TorchRec 模块和数据类型。为自己能走到这一步而感到自豪。接下来,我们将学习分布式训练和分片。
分布式训练和分片¶
现在我们已经掌握了 TorchRec 模块和数据类型,是时候将其提升到一个新的水平了。
请记住,TorchRec 的主要目的是为分布式嵌入提供原语。到目前为止,我们只在单个设备上使用嵌入表。鉴于嵌入表有多小,这已经成为可能,但在生产环境中,情况通常并非如此。嵌入表通常变得非常庞大,以至于一个表无法容纳在单个 GPU 上,从而产生了对多个设备和分布式环境的需求。
在本节中,我们将探索如何设置分布式环境,实际的生产训练是如何完成的,以及探索分片嵌入表,所有这些都使用 TorchRec 完成。
本节也仅使用 1 个 GPU,尽管它将以分布式方式处理。这仅是训练的限制,因为训练每个 GPU 都有一个进程。推理没有遇到此要求
在下面的示例代码中,我们设置了 PyTorch 分布式环境。
警告
如果您在 Google Colab 中运行此代码,则只能调用此单元格一次,再次调用它将导致错误,因为您只能初始化进程组一次。
import os
import torch.distributed as dist
# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"
# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")
print(f"Distributed environment initialized: {dist}")
Distributed environment initialized: <module 'torch.distributed' from '/usr/local/lib/python3.10/dist-packages/torch/distributed/__init__.py'>
分布式嵌入¶
我们已经使用过主要的 TorchRec 模块:EmbeddingBagCollection
。我们已经检查了它的工作原理以及数据如何在 TorchRec 中表示。但是,我们尚未探索 TorchRec 的主要部分之一,即分布式嵌入。
到目前为止,GPU 是 ML 工作负载最受欢迎的选择,因为它们能够比 CPU 执行更多数量级的浮点运算/秒(FLOPs)。但是,GPU 受限于稀缺的快速内存(HBM,类似于 CPU 的 RAM),通常约为 10GB。
RecSys 模型可以包含远远超过 1 个 GPU 内存限制的嵌入表,因此需要跨多个 GPU 分布嵌入表,也称为模型并行。另一方面,数据并行是指在每个 GPU 上复制整个模型,每个 GPU 接收不同的数据批次进行训练,并在后向传递中同步梯度。
模型中需要较少计算但需要更多内存的部分(嵌入)通过模型并行分布,而模型中需要更多计算和更少内存的部分(密集层、MLP 等)通过数据并行分布。
分片¶
为了分布嵌入表,我们将嵌入表拆分成几部分,并将这些部分放置到不同的设备上,也称为“分片”。
有很多方法可以分片嵌入表。最常见的方法是
表级分片:整个表完全放置在一个设备上
列级分片:嵌入表的列被分片
行级分片:嵌入表的行被分片
分片模块¶
虽然所有这些看起来都需要处理和实现很多东西,但您很幸运。TorchRec 提供了用于轻松进行分布式训练和推理的所有原语!事实上,TorchRec 模块有两个对应的类,用于在分布式环境中处理任何 TorchRec 模块
模块分片器:此类公开了一个
shard
API,用于处理分片 TorchRec 模块,从而生成分片模块。* 对于EmbeddingBagCollection
,分片器是 EmbeddingBagCollectionSharder分片模块:此类是 TorchRec 模块的分片变体。它具有与常规 TorchRec 模块相同的输入/输出,但经过更多优化,并且可以在分布式环境中工作。* 对于
EmbeddingBagCollection
,分片变体是 ShardedEmbeddingBagCollection
每个 TorchRec 模块都有一个未分片和分片变体。
未分片版本旨在进行原型设计和实验。
分片版本旨在用于分布式环境中进行分布式训练和推理。
TorchRec 模块的分片版本,例如 EmbeddingBagCollection
,将处理模型并行所需的一切,例如 GPU 之间用于将嵌入分布到正确 GPU 的通信。
回顾我们的 EmbeddingBagCollection
模块
ebc
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv
# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()
# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"
print(f"Process Group: {pg}")
Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f370ccd1130>
规划器¶
在我们展示分片如何工作之前,我们必须了解规划器,它可以帮助我们确定最佳分片配置。
给定多个嵌入表和多个 rank,可能存在许多不同的分片配置。例如,给定 2 个嵌入表和 2 个 GPU,您可以
在每个 GPU 上放置 1 个表
将两个表都放在单个 GPU 上,而另一个 GPU 上不放任何表
在每个 GPU 上放置某些行和列
考虑到所有这些可能性,我们通常希望获得一个对性能最佳的分片配置。
这就是规划器的作用所在。规划器能够根据嵌入表的数量和 GPU 的数量,确定最佳配置。事实证明,手动执行此操作非常困难,工程师必须考虑大量因素才能确保获得最佳分片计划。幸运的是,当使用规划器时,TorchRec 提供了自动规划器。
TorchRec 规划器
评估硬件的内存约束
根据内存提取(作为嵌入查找)估算计算量
解决特定于数据因素
考虑其他硬件特性(如带宽)以生成最佳分片计划
为了考虑到所有这些变量,TorchRec 规划器可以接收各种数量的嵌入表数据、约束、硬件信息和拓扑,以帮助为模型生成最佳分片计划,这通常在各个堆栈中提供。
要了解有关分片的更多信息,请参阅我们的分片教程。
# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size=1,
compute_device="cuda",
)
)
# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)
print(f"Sharding Plan generated: {plan}")
Sharding Plan generated: module:
param | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
product_table | table_wise | fused | [0]
user_table | table_wise | fused | [0]
param | shard offsets | shard sizes | placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0] | [4096, 64] | rank:0/cuda:0
user_table | [0, 0] | [4096, 64] | rank:0/cuda:0
规划器结果¶
如您在上面看到的,在运行规划器时,有很多输出。我们可以看到正在计算大量统计信息,以及我们的表最终被放置在哪里。
运行规划器的结果是一个静态计划,可以重复用于分片!这允许分片对于生产模型是静态的,而不是每次都确定新的分片计划。下面,我们使用分片计划最终生成我们的 ShardedEmbeddingBagCollection
。
# The static plan that was generated
plan
env = ShardingEnv.from_process_group(pg)
# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
print(f"Sharded EBC Module: {sharded_ebc}")
Sharded EBC Module: ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_output_dists):
TwPooledEmbeddingDist()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
使用 LazyAwaitable
进行 GPU 训练¶
请记住,TorchRec 是一个高度优化的库,用于分布式嵌入。TorchRec 引入的一个概念,旨在提高 GPU 训练的性能是 LazyAwaitable。您将看到 LazyAwaitable
类型作为各种分片 TorchRec 模块的输出。所有 LazyAwaitable
类型所做的就是尽可能延迟计算某些结果,并且它通过充当异步类型来实现这一点。
from typing import List
from torchrec.distributed.types import LazyAwaitable
# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
def __init__(self, size: List[int]) -> None:
super().__init__()
self._size = size
def _wait_impl(self) -> torch.Tensor:
return torch.ones(self._size)
awaitable = ExampleAwaitable([3, 2])
awaitable.wait()
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)
kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))
print(kt.keys())
print(kt.values().shape)
# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7f36956370d0>
<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])
分片 TorchRec 模块的剖析¶
我们现在已经成功地对 EmbeddingBagCollection
进行了分片,使用的分片计划是我们生成的!分片后的模块具有 TorchRec 中的通用 API,这些 API 抽象了多个 GPU 之间的分布式通信/计算。事实上,这些 API 针对训练和推理的性能进行了高度优化。 以下是 TorchRec 提供的三个用于分布式训练/推理的通用 API
input_dist
: 处理从 GPU 到 GPU 的输入分发。lookups
: 使用 FBGEMM TBE 以优化的、批处理的方式执行实际的嵌入查找(稍后会详细介绍)。output_dist
: 处理从 GPU 到 GPU 的输出分发。
输入和输出的分发通过 NCCL Collectives 完成,具体来说是 All-to-Alls,所有 GPU 在彼此之间发送和接收数据。TorchRec 与 PyTorch 分布式接口对接进行集合通信,并为最终用户提供简洁的抽象,消除了对底层细节的担忧。
反向传播执行所有这些集合通信,但顺序相反,用于梯度分发。input_dist
、lookup
和 output_dist
都依赖于分片方案。由于我们以表方式进行分片,因此这些 API 是由 TwPooledEmbeddingSharding 构建的模块。
sharded_ebc
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists
# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
[TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)]
优化嵌入查找¶
在执行嵌入表的集合的查找时,一个简单的解决方案是迭代所有 nn.EmbeddingBags
并为每个表执行查找。这正是标准的、未分片的 EmbeddingBagCollection
所做的。然而,虽然这个解决方案很简单,但它非常慢。
FBGEMM 是一个提供高度优化的 GPU 算子(也称为内核)的库。其中一个算子被称为 Table Batched Embedding (TBE),它提供了两个主要的优化:
表批处理,它允许您通过一次内核调用查找多个嵌入。
优化器融合,它允许模块使用规范的 PyTorch 优化器和参数来更新自身。
ShardedEmbeddingBagCollection
使用 FBGEMM TBE 作为查找,而不是传统的 nn.EmbeddingBags
,以实现优化的嵌入查找。
sharded_ebc._lookups
[GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)]
DistributedModelParallel
¶
我们现在已经探索了对单个 EmbeddingBagCollection
进行分片!我们能够使用 EmbeddingBagCollectionSharder
和未分片的 EmbeddingBagCollection
来生成 ShardedEmbeddingBagCollection
模块。这个工作流程很好,但通常在实现模型并行时,DistributedModelParallel (DMP) 被用作标准接口。当用 DMP 包装您的模型(在本例中为 ebc
)时,将发生以下情况:
决定如何对模型进行分片。DMP 将收集可用的分片器,并提出对嵌入表进行分片的最佳方法计划(例如,
EmbeddingBagCollection
)。实际对模型进行分片。这包括在适当的设备上为每个嵌入表分配内存。
DMP 接受我们刚刚实验过的所有内容,例如静态分片计划、分片器列表等。但是,它也有一些不错的默认设置,可以无缝地对 TorchRec 模型进行分片。在这个玩具示例中,由于我们有两个嵌入表和一个 GPU,TorchRec 会将它们都放在单个 GPU 上。
ebc
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
out = model(kjt)
out.wait()
model
DistributedModelParallel(
(_dmp_wrapped_module): ShardedEmbeddingBagCollection(
(lookups):
GroupedPooledEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedFusedEmbeddingBag(
(_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
)
)
)
(_input_dists):
TwSparseFeaturesDist(
(_dist): KJTAllToAll()
)
(_output_dists):
TwPooledEmbeddingDist(
(_dist): PooledEmbeddingsAllToAll()
)
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
分片最佳实践¶
目前,我们的配置仅在 1 个 GPU(或 rank)上进行分片,这很简单:只需将所有表放在 1 个 GPU 内存中即可。然而,在实际生产用例中,嵌入表通常在数百个 GPU 上进行分片,采用不同的分片方法,例如按表分片、按行分片和按列分片。确定适当的分片配置非常重要(以防止内存不足问题),同时保持内存和计算方面的平衡,以获得最佳性能。
添加优化器¶
请记住,TorchRec 模块针对大规模分布式训练进行了高度优化。一个重要的优化是关于优化器的。
TorchRec 模块提供了一个无缝的 API,可以在训练中融合反向传播和优化步骤,从而显着优化性能并减少内存使用,同时还可以在为不同的模型参数分配不同的优化器时提供粒度。
优化器类¶
TorchRec 使用 CombinedOptimizer
,其中包含 KeyedOptimizers
的集合。CombinedOptimizer
有效地简化了为模型中各种子组处理多个优化器的操作。KeyedOptimizer
扩展了 torch.optim.Optimizer
,并通过参数字典初始化,公开了这些参数。EmbeddingBagCollection
中的每个 TBE
模块都将有自己的 KeyedOptimizer
,它们组合成一个 CombinedOptimizer
。
TorchRec 中的融合优化器¶
使用 DistributedModelParallel
,优化器是融合的,这意味着优化器更新在反向传播中完成。这是 TorchRec 和 FBGEMM 中的一项优化,其中优化器嵌入梯度不会物化,而是直接应用于参数。这带来了显着的内存节省,因为嵌入梯度通常与参数本身的大小相同。
但是,您可以选择使优化器 dense
,这不会应用此优化,并允许您检查嵌入梯度或根据需要对其应用计算。在这种情况下,密集优化器将是您的 规范 PyTorch 模型训练循环和优化器。
通过 DistributedModelParallel
创建优化器后,您仍然需要为与 TorchRec 嵌入模块无关的其他参数管理优化器。要查找其他参数,请使用 in_backward_optimizer_filter(model.named_parameters())
。像对待普通的 Torch 优化器一样,将优化器应用于这些参数,并将此优化器和 model.fused_optimizer
组合到一个 CombinedOptimizer
中,您可以在训练循环中使用它来 zero_grad
和 step
。
向 EmbeddingBagCollection
添加优化器¶
我们将以两种方式执行此操作,这两种方式是等效的,但根据您的偏好为您提供选项:
通过分片器中的
fused_params
传递优化器 kwargs。通过
apply_optimizer_in_backward
,它将优化器参数转换为fused_params
以传递给EmbeddingBagCollection
或EmbeddingCollection
中的TBE
。
# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
from fbgemm_gpu.split_embedding_configs import EmbOptimType
# We initialize the sharder with
fused_params = {
"optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
"learning_rate": 0.02,
"eps": 0.002,
}
# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))
# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")
from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
import copy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it
# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}
for name, param in ebc_apply_opt.named_parameters():
print(f"{name=}")
apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)
sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))
# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))
# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())
# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")
loss.backward()
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.01
)
Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.02
)
Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
name='embedding_bags.product_table.weight'
name='embedding_bags.user_table.weight'
: EmbeddingFusedOptimizer (
Parameter Group 0
lr: 0.5
)
<class 'torchrec.optim.keyed.CombinedOptimizer'>
Non Fused Model Parameters:
dict_keys([])
First Iteration Loss: 255.66006469726562
Second Iteration Loss: 245.43795776367188
推理¶
既然我们能够训练分布式嵌入,我们如何才能获取训练好的模型并针对推理对其进行优化?推理通常对性能和模型大小非常敏感。仅在 Python 环境中运行训练好的模型效率非常低。推理和训练环境之间有两个关键区别:
量化:推理模型通常是量化的,其中模型参数会损失精度,以实现更低的预测延迟和更小的模型大小。例如,训练模型中的 FP32(4 字节)到每个嵌入权重的 INT8(1 字节)。鉴于嵌入表规模庞大,这也是必要的,因为我们希望尽可能少地使用设备进行推理,以最大限度地减少延迟。
C++ 环境:推理延迟非常重要,因此为了确保充足的性能,模型通常在 C++ 环境中运行,以及在没有 Python 运行时的环境中运行,例如在设备上。
TorchRec 提供了将 TorchRec 模型转换为可用于推理的基元,包括:
用于量化模型的 API,自动引入 FBGEMM TBE 优化
用于分布式推理的分片嵌入
将模型编译为 TorchScript(与 C++ 兼容)
在本节中,我们将介绍以下完整工作流程:
量化模型
分片量化模型
将分片量化模型编译为 TorchScript
ebc
class InferenceModule(torch.nn.Module):
def __init__(self, ebc: torchrec.EmbeddingBagCollection):
super().__init__()
self.ebc_ = ebc
def forward(self, kjt: KeyedJaggedTensor):
return self.ebc_(kjt)
module = InferenceModule(ebc)
for name, param in module.named_parameters():
# Here, the parameters should still be FP32, as we are using a standard EBC
# FP32 is default, regularly used for training
print(name, param.shape, param.dtype)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32
ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32
量化¶
如您在上面看到的,正常的 EBC 包含作为 FP32 精度(每个权重 32 位)的嵌入表权重。在这里,我们将使用 TorchRec 推理库将模型的嵌入权重量化为 INT8。
from torch import quantization as quant
from torchrec.modules.embedding_configs import QuantConfig
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
quant_dtype = torch.int8
qconfig = QuantConfig(
# dtype of the result of the embedding lookup, post activation
# torch.float generally for compatibility with rest of the model
# as rest of the model here usually isn't quantized
activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
# quantized type for embedding weights, aka parameters to actually quantize
weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
# Map of module type to qconfig
torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
# Map of module type to quantized module type
torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}
module = InferenceModule(ebc)
# Quantize the module
qebc = quant.quantize_dynamic(
module,
qconfig_spec=qconfig_spec,
mapping=mapping,
inplace=False,
)
print(f"Quantized EBC: {qebc}")
kjt = kjt.to("cpu")
qebc(kjt)
# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
# The shapes of the tables should be the same but the dtype should be int8 now
# post quantization
print(name, buffer.shape, buffer.dtype)
Quantized EBC: InferenceModule(
(ebc_): QuantizedEmbeddingBagCollection(
(_kjt_to_jt_dict): ComputeKJTToJTDict()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
)
)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8
ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8
分片¶
这里我们执行 TorchRec 量化模型的分片。这是为了确保我们通过 FBGEMM TBE 使用高性能模块。这里我们使用一个设备,以与训练(1 个 TBE)保持一致。
from torchrec import distributed as trec_dist
from torchrec.distributed.shard import _shard_modules
sharded_qebc = _shard_modules(
module=qebc,
device=torch.device("cpu"),
env=trec_dist.ShardingEnv.from_local(
1,
0,
),
)
print(f"Sharded Quantized EBC: {sharded_qebc}")
sharded_qebc(kjt)
Sharded Quantized EBC: InferenceModule(
(ebc_): ShardedQuantEmbeddingBagCollection(
(lookups):
InferGroupedPooledEmbeddingsLookup()
(_output_dists): ModuleList()
(embedding_bags): ModuleDict(
(product_table): Module()
(user_table): Module()
)
(_input_dist_module): ShardedQuantEbcInputDist()
)
)
<torchrec.sparse.jagged_tensor.KeyedTensor object at 0x7f3695878310>
编译¶
现在我们有了优化的 eager TorchRec 推理模型。下一步是确保此模型可以在 C++ 中加载,因为它目前只能在 Python 运行时中运行。
Meta 推荐的编译方法是双重的:torch.fx tracing(生成模型的中间表示)并将结果转换为 TorchScript,其中 TorchScript 与 C++ 兼容。
from torchrec.fx import Tracer
tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])
graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)
print("Graph Module Created!")
print(gm.code)
scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")
print(scripted_gm.code)
Graph Module Created!
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embeddingbag_flatten_feature_lengths")
torch.fx._symbolic_trace.wrap("torchrec_fx_utils__fx_marker")
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embedding_kernel__unwrap_kjt")
torch.fx._symbolic_trace.wrap("fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device")
torch.fx._symbolic_trace.wrap("torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference")
def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):
flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt); kjt = None
_fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths); _fx_marker = None
split = flatten_feature_lengths.split([2])
getitem = split[0]; split = None
to = getitem.to(device(type='cuda', index=0), non_blocking = True); getitem = None
_fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths); flatten_feature_lengths = _fx_marker_1 = None
_unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to); to = None
getitem_1 = _unwrap_kjt[0]
getitem_2 = _unwrap_kjt[1]
getitem_3 = _unwrap_kjt[2]; _unwrap_kjt = getitem_3 = None
inputs_to_device = fbgemm_gpu_split_table_batched_embeddings_ops_inference_inputs_to_device(getitem_1, getitem_2, None, device(type='cuda', index=0)); getitem_1 = getitem_2 = None
getitem_4 = inputs_to_device[0]
getitem_5 = inputs_to_device[1]
getitem_6 = inputs_to_device[2]; inputs_to_device = None
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6); _tensor_constant0 = _tensor_constant1 = bounds_check_indices = None
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_4, offsets = getitem_5, pooling_mode = 0, indice_weights = getitem_6, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1); _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_4 = getitem_5 = getitem_6 = _tensor_constant8 = _tensor_constant9 = None
embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32); int_nbit_split_embedding_codegen_lookup_function = None
to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu')); embeddings_cat_empty_rank_handle_inference = None
keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1); to_1 = None
return keyed_tensor
/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning:
The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
Scripted Graph Module Created!
def forward(self,
kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:
_0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths
_1 = __torch__.torchrec.fx.utils._fx_marker
_2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt
_3 = __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_inference.inputs_to_device
_4 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference
flatten_feature_lengths = _0(kjt, )
_fx_marker = _1("KJT_ONE_TO_ALL_FORWARD_BEGIN", flatten_feature_lengths, )
split = (flatten_feature_lengths).split([2], )
getitem = split[0]
to = (getitem).to(torch.device("cuda", 0), True, None, )
_fx_marker_1 = _1("KJT_ONE_TO_ALL_FORWARD_END", flatten_feature_lengths, )
_unwrap_kjt = _2(to, )
getitem_1 = (_unwrap_kjt)[0]
getitem_2 = (_unwrap_kjt)[1]
inputs_to_device = _3(getitem_1, getitem_2, None, torch.device("cuda", 0), )
getitem_4 = (inputs_to_device)[0]
getitem_5 = (inputs_to_device)[1]
getitem_6 = (inputs_to_device)[2]
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_4, getitem_5, 1, _tensor_constant1, getitem_6)
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
_tensor_constant7 = self._tensor_constant7
_tensor_constant8 = self._tensor_constant8
_tensor_constant9 = self._tensor_constant9
int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_4, getitem_5, 0, getitem_6, 0, _tensor_constant8, _tensor_constant9, 16)
_5 = [int_nbit_split_embedding_codegen_lookup_function]
embeddings_cat_empty_rank_handle_inference = _4(_5, 1, "cuda:0", 6, )
to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device("cpu"))
_6 = ["product", "user"]
_7 = [64, 64]
keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)
_8 = (keyed_tensor).__init__(_6, _7, to_1, 1, None, None, )
return keyed_tensor
结论¶
在本教程中,您已经完成了从训练分布式 RecSys 模型到使其可用于推理的整个过程。TorchRec repo 有一个完整的示例,说明如何将 TorchRec TorchScript 模型加载到 C++ 中进行推理。
有关更多信息,请参阅我们的 dlrm 示例,其中包括使用 Deep Learning Recommendation Model for Personalization and Recommendation Systems 中描述的方法在 Criteo 1TB 数据集上进行多节点训练。
脚本总运行时间: ( 0 分钟 0.820 秒)