作者:Meta 的 PyTorch 团队:Chunzhi Yang, Rich Zhu, Zain Huda, Liangbei Xu, Xin Zhang, Jiyan Yang, Dennis van der Staay, Wang Zhou, Jin Fang, Jade Nie, Yuxi Hu

在 Meta,推荐系统是向全球数十亿用户提供相关个性化广告的基石。通过 PyTorch 的 TorchRec 等技术,我们成功开发了能够在数百个 GPU 上进行模型训练的解决方案。虽然这些系统一直运行良好,但最近关于缩放定律(scaling laws)的研究揭示了一个引人注目的机会:通过训练规模显著增大的神经网络,我们可以获得显著更好的模型性能。

然而,这一发现也带来了新的挑战。我们目前的训练基础设施虽然针对数百个 GPU 进行了高度优化,但无法有效扩展到训练这些更大模型所需的数千个 GPU。从数百个 GPU 跃升到数千个 GPU 引入了复杂的技术挑战,尤其是在处理推荐模型中的稀疏操作方面。这些挑战需要全新的分布式训练方法,而我们通过一种新颖的并行化策略来解决这些问题。

为了解决这些问题,我们引入了 2D embedding parallel(2D 嵌入并行),这是一种新颖的并行化策略,克服了在数千个 GPU 上训练大型推荐模型固有的稀疏扩展挑战。目前可通过 TorchRec 中的 DMPCollection API 使用此功能。 这种方法结合了两种互补的并行化技术:对模型稀疏组件采用数据并行,以及对嵌入表采用模型并行,并充分利用了 TorchRec 强大的分片能力。通过战略性地整合这些技术,我们创建了一种可扩展到数千个 GPU 的解决方案,该方案目前为 Meta 规模最大的推荐模型训练提供了支持。

稀疏扩展挑战是什么?

我们确定了阻碍我们将模型简单扩展到数千个 GPU 的三个关键挑战

  • 不平衡和掉队问题: GPU 越多,实现均衡分片就越困难,一些 rank 的嵌入计算工作负载可能重得多,从而拖慢整个训练过程。
  • 跨节点通信: 随着训练作业使用的 GPU 数量增加,在某些网络拓扑下,all-to-all 通信带宽可能会下降,从而显著增加通信延迟。
  • 内存开销: 输入特征占用的内存通常可以忽略不计,但当我们使用数千个 GPU 时,可能会引入更大的输入特征,此时内存需求会变得非常大。

使用 2D 嵌入并行,我们可以这样描述我们的新并行方案。在这个例子中,我们有 2 个模型副本(副本 1:GPU1/GPU3,副本 2:GPU2/GPU4)

Flow diagram

图 1:2D 稀疏并行布局图示

通过 2D 稀疏并行,我们解决了这些挑战。我们不是将表分片到所有 rank 上,而是首先将所有 rank 平均分成若干并行组

  1. 在每个组内,我们对嵌入表使用模型并行,例如列式/行式分片。在大规模应用中,对于我们最大的表,我们还开发了一种网格分片(grid sharding),它在行和列维度上对嵌入表进行分片。
  2. 跨组则采用数据并行,使得一个组中的每个 rank 在其他组中都有其对应的副本 rank(副本 rank 意味着存储相同的嵌入表分片)。
    1. 每个组完成自己的反向传播后,我们在副本之间对嵌入表权重执行 all-reduce 操作以保持同步。

我们的生产解决方案

TorchRec 是我们在原生 PyTorch 中构建推荐模型稀疏部分的库。传统的 API 是 DistributedModelParallel,它对嵌入表应用模型并行。我们同时引入了一个新的 API,称为 DMPCollection,它是 TorchRec 模型上启用 2D 并行的主要入口。我们设计它的目标是使其更改就像应用 FSDP/DDP 一样简单。

要理解 DMPCollection 的作用,我们首先必须理解 DistributedModelParallel (DMP) 的作用

  1. 创建嵌入表,称为 EmbeddingBagCollection 和 EmbeddingCollections。
  2. 根据 GPU 拓扑、嵌入表、可用内存、输入数据等生成分片计划。
  3. 使用 DMP 包装模型,并传入相关的分片计划。
  4. DMP 根据分片计划初始化并分片嵌入表。
  5. 在训练步骤中,DMP 接收输入批次,将其通信给包含相关嵌入表分片的适当 GPU,查找值,然后将其返回给请求该值的 GPU。这一切都在全局进程组(global process group)上完成,特殊分片(如表行式分片)除外。

DistributedModelParallel 是为模型并行构建的,其许多部分都在分片和围绕全局 world size 的假设下工作。我们需要以一种方式修改这些部分,以便在不损失 TorchRec 优化和功能集的情况下引入额外的并行维度。

DMPCollection 修改了一些关键部分,以一种可扩展的方式启用 2D 并行,

  • 一次性为较小的分片组生成分片计划,一旦传入,我们就通过全局组与适当的 rank 进行通信,并重新映射 rank 以适应新的分片组 rank。
  • 创建两个新的 NCCL 进程组,称为分片进程组(sharding process group)和副本进程组(replica process group)。分片进程组被传递到 TorchRec 的分片和训练步骤组件中。副本进程组用于权重和优化器状态同步,all reduce 调用在此进程组上进行。
    • 子 NCCL 进程组使我们能够高效地仅在特定通信相关的 rank 之间进行通信。每个 rank 将拥有两个关联的进程组。

对用户来说,更改非常简单,同时消除了将并行化策略应用于模型的所有复杂性。

我们如何创建这些分片组和复制组?

这些进程组是 DMPCollection 实现高性能的关键之一。从我们之前的图示中,我们展示了一个简单的 2x2 GPU 设置,但在大规模情况下,我们如何确定哪些 rank 属于某个特定的分片组,以及它们在不同分片组中的副本 rank 是什么?

考虑以下设置:2 个节点,每个节点有 4 个 GPU。在 2D 并行下,分片组和复制组将是,

分片组 分片 Rank
0 0, 2, 4, 6
1 1, 3, 5, 7
复制组 复制 Rank
0 0, 1
1 2, 3
2 4, 5
3 6, 7

我们使用以下公式,

  1. 将所有训练器(trainer)分为 G 个分片组,每个组包含 L 个训练器
    1. 组数 G 由 G = T / L 决定,其中 T 是训练器总数
  2. 对于每个组 G,我们根据其所属组分配非连续的训练器 rank,规则如下,
    1. [i, G+i, 2G+i, …, (L - 1) G+i],其中* i = 0 到 G-1*
  3. 从组 G 中,我们可以创建复制组,即每 G 个连续的 rank
    1. (0 到 G-1,G 到 2* G - 1) 每个连续集合存储重复的嵌入表分片。

这意味着我们的分片组 G 的大小为 L,可以认为是应用模型并行所跨越的 rank 数量。这反过来又给了我们大小为 G 的复制组,这些组是我们进行数据并行所跨越的 rank。

在 DMPCollection 中,我们能够利用 DeviceMesh 有效地创建这些进程组。我们在一个 2x2 矩阵中创建了整个 GPU 拓扑,其中每一行代表分片 rank 组,每一列代表相应的副本 rank,

create peer matrix
num_groups = global_world_size // sharding_group_size
for each group_rank in num_groups:
	peers = [num_groups * rank + group_rank for rank in range(sharding_group_size)]
	add peer to peer matrix

initalize DeviceMesh with two dimensions (shard, replicate)
slice DeviceMesh on shard for sharding process group
slide DeviceMesh on replicate for replica process group

通过我们的 DeviceMesh 方法,如果将来需要更改拓扑结构或提供进一步的灵活性,我们可以轻松地将创建逻辑扩展到任何形式的拓扑结构,甚至在需要时扩展到更多的并行维度。

2D 并行的性能

我们的 rank 分区策略通过在同一计算节点内策略性地放置每个分片的模型副本 rank 来优化通信模式。这种架构为权重同步操作带来了显著的性能优势。在反向传播之后,我们执行 all-reduce 操作来同步模型权重——考虑到我们需要通信和同步的大量参数,这是一个开销很大的过程——通过将副本放置在同一节点上的设置,我们利用了节点内的高带宽,而不是依赖较慢的跨节点带宽。

这种设计选择对其他通信集合体(communication collectives)的影响通常会改善延迟。改进主要源于两个因素。

  1. 通过在较少数量的 rank 上对嵌入表进行分片并在较小的组内进行模型通信,我们实现了较低的 all-to-all 延迟。
  2. 通过 2D 并行中的复制,我们在单个 rank 上的嵌入查找延迟降低了。我们可以将本地批次大小减少到等效全局批次大小的 1/N,其中 N 是模型副本的数量。

一个生产模型的跟踪(trace)示例说明了这两个因素。在这里,我们在 1024 个 GPU 上运行 2D 并行作业,分片组大小为 256 个 GPU。

State diagram

图 2:非 2D 并行与 2D 并行工作负载的延迟比较

用户有两个关键杠杆可以调整以最大限度地提升工作负载性能

  1. 模型分片组的大小相对于全局 world size。全局 world size 除以分片组大小表示我们将拥有的模型副本数量。
    1. 为了最大限度地提升性能,用户可以考虑将模型扩展到 8 倍。这个扩展因子维持了主机内(intra-host)的 all reduce。
      1. 要实现进一步的扩展,all reduce 必须跨主机(inter-host)进行。从我们的实验来看,我们没有看到明显的性能下降,实际上还注意到了跨主机 all reduce 的优势。我们可以将分片和复制拓扑更改为跨主机 all reduce,这有助于我们在特定主机发生故障时引入容错策略。
  2. all reduce 同步频率。DMPCollection 提供了 sync() 调用,可以调整为每 N 个训练步调用一次,进行一种局部 SGD 训练。随着规模扩大,降低同步频率可以显著提升性能。

未来工作

读者应该注意,2D 稀疏并行训练与非并行化训练不同之处在于我们同步的是嵌入表权重而不是梯度。这种方法得益于 TorchRec 对 FBGEMM 的使用,FBGEMM 在底层提供了优化的内核。FBGEMM 的一个关键优化是反向传播中优化器的融合。嵌入表梯度没有完全具体化(这将消耗大量内存),而是直接传递给优化器更新。尝试具体化和同步这些梯度会产生巨大的开销,使得这种方法不切实际。

我们的探索发现,要达到与基线(baseline)相当的训练结果,我们需要延迟同步优化器状态,同步时机取决于分片/复制组的数量(例如:对于 Adagrad,我们滞后一个同步步长来更新动量)。这种方法还使用户能够实现局部 SGD 或半同步训练策略,这些策略可以实现收敛,并有可能产生比基线更好的损失曲线。

感谢您阅读我们的文章!这是一个令人兴奋的方向,我们希望进一步发展,以最大限度地提高推荐系统的性能并推动技术的前沿。