跳转到主要内容

在Meta,推荐系统是向全球数十亿用户提供相关和个性化广告的基石。通过像 PyTorch 的 TorchRec 这样的技术,我们成功开发了能够在数百个 GPU 上进行模型训练的解决方案。尽管这些系统运行良好,但近期关于扩展定律的研究揭示了一个令人兴奋的机会:通过训练显著更大的神经网络,我们可以实现更好的模型性能。

然而,这一洞察力给我们带来了新的挑战。我们当前的训练基础设施,虽然已针对数百个 GPU 进行了高度优化,但无法有效地扩展到训练这些更大模型所需的数千个 GPU。从数百个 GPU 跃升到数千个 GPU 引入了复杂的技术挑战,尤其是在处理推荐模型中的稀疏操作方面。这些挑战需要全新的分布式训练方法,我们通过一种新颖的并行策略来解决这些问题。

为了解决这些问题,我们引入了 2D 嵌入并行,这是一种新颖的并行策略,克服了在数千个 GPU 上训练大型推荐模型固有的稀疏扩展挑战。目前,通过 TorchRec 的 DMPCollection API 即可使用此功能。 这种方法结合了两种互补的并行化技术:针对模型稀疏组件的数据并行和针对嵌入表的模型并行,利用了 TorchRec 强大的分片能力。通过战略性地整合这些技术,我们创建了一个可扩展到数千个 GPU 的解决方案,目前它正在为 Meta 最大的推荐模型训练运行提供支持。

稀疏扩展的挑战是什么?

我们确定了三个主要挑战,这些挑战阻碍了我们天真地将模型扩展到数千个 GPU:

  • 不平衡和慢节点问题: 随着 GPU 数量的增加,实现平衡分片变得更加困难,一些节点可能承担更重的嵌入计算负载,这会减慢整个训练过程。
  • 跨节点通信: 随着训练作业使用更多 GPU,在某些网络拓扑下,all-to-all 通信带宽可能会下降,这会显著增加通信延迟。
  • 内存开销: 输入特征使用的内存通常可以忽略不计,但是,当我们使用数千个 GPU 时,我们可以引入更大的输入特征,内存需求可能会变得非常大。

通过 2D 嵌入并行,我们可以这样描述我们的新并行方案,在此示例中,我们有 2 个模型副本(副本 1:GPU1/GPU3,副本 2:GPU2/GPU4)

Flow diagram

图1:2D稀疏并行布局示意图

通过 2D 稀疏并行,我们解决了这些挑战,不再将表分片到所有等级,而是首先将所有等级均匀划分为几个并行组。

  1. 在每个组中,我们对嵌入表使用模型并行,例如列式/行式分片。在大规模部署中,对于我们最大的表,我们还开发了网格分片,它在行和列维度上对嵌入表进行分片。
  2. 跨组进行数据并行,使得一个组中的每个秩在其他组中都有其对应的副本秩(副本秩意味着存储相同的嵌入表分片)。
    1. 每个组完成自己的反向传播后,我们对副本之间的嵌入表权重进行 All-Reduce 操作,以保持它们同步。

我们的生产解决方案

TorchRec 是我们用于在原生 PyTorch 中构建推荐模型稀疏部分的库。传统的 API 是 DistributedModelParallel,它将模型并行应用于嵌入表。我们引入了一个新的 API,称为 DMPCollection,它作为在 TorchRec 模型上启用 2D 并行的主要入口点。我们将其设计得像应用 FSDP/DDP 一样简单。

要理解 DMPCollection 的作用,我们必须首先理解 DistributedModelParallel (DMP) 的作用

  1. 创建嵌入表,称为 EmbeddingBagCollection 和 EmbeddingCollections。
  2. 生成关于 GPU 拓扑、嵌入表、可用内存、输入数据等的分片计划。
  3. 使用 DMP 和传入的相关分片计划包装模型。
  4. DMP 根据分片计划初始化和分片嵌入表。
  5. 在训练步骤中,DMP 接收输入批次,将其通信到包含感兴趣的嵌入表分片的相应 GPU,查找值,然后将其返回给请求它的 GPU。所有这些都在全局进程组上完成,特殊分片(例如表行式分片)除外。

DistributedModelParallel 是为模型并行构建的,其许多部分都假设分片并围绕全局世界大小工作。我们需要以一种方式改变这些部分,以便我们可以引入额外的并行维度,而不会丢失 TorchRec 的优化和功能集。

DMPCollection 更改了几个关键部分,以可扩展的方式启用 2D 并行,

  • 一次为较小的分片组生成分片计划,一旦传入,我们就会将计划通信到全局组中相应的进程,并重新映射进程以适应新的分片组进程。
  • 创建两个新的 NCCL 进程组,称为分片进程组和副本进程组。分片进程组被传递到 TorchRec 的分片和训练步骤组件中。副本进程组用于权重和优化器状态同步,all reduce 调用发生在此进程组上。
    • 子 NCCL 进程组允许我们只在与特定通信相关的进程之间进行高效通信。每个进程将有两个关联的进程组。

对于用户来说,这种改变非常简单,同时消除了将并行策略应用于模型的所有复杂性。

我们如何创建这些分片组和复制组?

这些进程组是 DMPCollection 高性能实现的关键之一。从我们之前的图中,我们展示了一个简单的 2×2 GPU 设置,然而,在大规模部署中,我们如何分配哪些进程属于给定的分片组,以及它们在不同分片组中的副本进程是什么?

考虑以下设置:2 个节点,每个节点有 4 个 GPU。2D 并行下的分片组和复制组将是:

分片组分片秩00、2、4、611、3、5、7 复制组复制秩00、112、324、536、7

我们使用以下公式,

  1. 将所有训练器分成 G 个分片组,每个组有 L 个训练器。
    1. 组数 G 由 G = T / L 确定,其中 T 是训练器总数。
  2. 对于每个组 G,我们根据其所在的组分配非连续的训练器秩,遵循:
    1. [i, G+i, 2G+i, …, (L – 1) G+i],其中 * i = 0 到 G-1 *
  3. 从这些组G中,我们可以创建复制组,即每G个连续的秩
    1. (0 到 G-1,G 到 2*G-1)每个连续集存储重复的嵌入表分片。

这意味着我们的分片组 G 的大小为 L,可以理解为需要应用模型并行的进程数量。这反过来又为我们提供了复制组,每个组的大小为 G,这些是我们需要进行数据并行的进程。

在 DMPCollection 中,我们可以利用 DeviceMesh 高效地创建这些进程组,我们将整个 GPU 拓扑创建为一个 2x2 矩阵,其中每行代表分片进程组,每列代表相应的副本进程组。

create peer matrix
num_groups = global_world_size // sharding_group_size
for each group_rank in num_groups:
	peers = [num_groups * rank + group_rank for rank in range(sharding_group_size)]
	add peer to peer matrix

initalize DeviceMesh with two dimensions (shard, replicate)
slice DeviceMesh on shard for sharding process group
slide DeviceMesh on replicate for replica process group

通过我们的 DeviceMesh 方法,如果未来我们想改变拓扑或提供更大的灵活性,我们可以轻松地将我们的创建逻辑扩展到任何形式的拓扑,甚至在需要时扩展到更多的并行维度。

2D 并行的性能

我们的秩分区策略通过将每个分片的模型副本秩战略性地放置在同一个计算节点内,从而优化了通信模式。这种架构为权重同步操作提供了显著的性能优势。在反向传播之后,我们执行 all-reduce 操作以同步模型权重——考虑到我们需要通信和同步的大量参数,这是一个昂贵的过程——通过将副本放置在同一个节点上,我们利用了节点内部的高带宽,而不是过度依赖较慢的节点间带宽。

这种设计选择对其他集体通信的影响通常会改善延迟。这种改进源于两个因素。

  1. 通过将嵌入表分片到更少的秩上,并在较小的组内进行模型通信,我们实现了更低的 all-to-all 延迟。
  2. 通过2D并行中的复制,我们的嵌入查找延迟在一个秩上减少了,我们可以将本地批次大小减少到等效全局批次大小的1/N,其中N是模型副本的数量。

一个生产模型跟踪示例说明了这两个因素,这里我们在一千零二十四核 GPU 上运行 2D 并行作业,分片组大小为两百五十六核 GPU。

State diagram

图2:非2D并行与2D并行工作负载的延迟比较

用户可以通过两个关键杠杆来优化其工作负载的性能:

  1. 模型分片组的大小相对于全局世界大小。全局世界大小除以分片组大小代表我们将拥有的模型副本数量。
    1. 为了最大限度地提高性能,用户可以将模型扩展到 8 倍,此扩展因子可保持节点内 all-reduce。
      1. 为了进一步扩展,all reduce 必须在主机间进行。根据我们的实验,我们没有看到明显的性能退化,事实上,我们注意到了主机间 all reduce 的优势。我们可以改变分片和副本拓扑以进行主机间 all reduce,这有助于我们在特定主机宕机时引入容错策略。
  2. all reduce 同步的频率,DMPCollection 带有 sync() 调用,可以调整为每 N 个训练步调用一次,执行一种本地 SGD 训练。随着规模的扩大,降低同步频率可以显著提升性能。

未来工作

读者应该注意,2D 稀疏并行训练与非并行训练不同,因为我们同步的是嵌入表权重而不是梯度。这种方法之所以可行,得益于 TorchRec 使用 FBGEMM,后者在底层提供了优化的内核。FBGEMM 的一个关键优化是在反向传播中融合优化器。它不是完全实例化嵌入表梯度(这会消耗大量内存),而是将它们直接传递给优化器更新。尝试实例化和同步这些梯度会产生大量开销,使该方法不切实际。

我们的探索表明,为了获得与基线相当的训练结果,我们以延迟调度的方式同步优化器状态,其时机取决于分片/副本组的数量(例如:对于 Adagrad,我们落后一个同步步骤更新动量)。这种方法还允许用户实现本地 SGD 或半同步训练策略,这可以实现收敛并可能产生比基线更好的损失曲线。

感谢您的阅读!我们发现了一个令人兴奋的方向,希望进一步发展,以最大化推荐系统的性能并推动最先进技术的发展。