在过去一年中,混合专家 (MoE) 模型人气飙升,得益于强大的开源模型,如 DBRX、Mixtral、DeepSeek 等。在 Databricks,我们与 PyTorch 团队紧密合作,扩展 MoE 模型的训练规模。在这篇博文中,我们将讨论如何使用 PyTorch 分布式 和 MegaBlocks(PyTorch 中一个高效的开源 MoE 实现)扩展到三千多个 GPU 进行训练。
什么是 MoE 模型?
MoE 模型是一种模型架构,它使用多个专家网络进行预测。一个门控网络用于路由和组合专家的输出,确保每个专家都在不同的、专门的 token 分布上进行训练。基于 Transformer 的大型语言模型的架构通常包含一个嵌入层,该层连接到多个 Transformer 块(图 1,子图 A)。每个 Transformer 块包含一个注意力块和一个密集前馈网络(图 1,子图 B)。这些 Transformer 块堆叠在一起,使得一个 Transformer 块的输出成为下一个块的输入。最终输出通过一个全连接层和 softmax 来获得下一个 token 输出的概率。
在 LLMs 中使用 MoE 时,密集前馈层被替换为一个 MoE 层,该层包含一个门控网络和多个专家(图 1,子图 D)。门控网络,通常是一个线性前馈网络,接收每个 token 并生成一组权重,这些权重决定了哪些 token 被路由到哪些专家。专家本身通常也实现为前馈网络。在训练期间,门控网络会调整以将输入分配给专家,从而使模型能够进行专门化并提高其性能。然后使用路由器的输出对专家的输出进行加权,以获得 MoE 层的最终输出。
图 1:在 Transformer 块中使用混合专家模型
与密集模型相比,在给定的计算预算下,MoE 模型提供了更高效的训练。这是因为门控网络只将 token 发送到专家子集,从而减少了计算负载。因此,可以在不按比例增加计算需求的情况下增加模型的容量(总参数数量)。在推理期间,只使用部分专家,因此 MoE 模型能够比密集模型执行更快的推理。然而,整个模型需要加载到内存中,而不仅仅是正在使用的专家。
MoE 模型实现更高计算效率的稀疏性源于一个事实:特定的 token 只会被路由到专家子集。专家数量以及如何选择专家取决于门控网络的实现,但一种常用方法是 top k。门控网络首先预测每个专家的概率值,然后将 token 路由到 top k 个专家以获得输出。然而,如果所有 token 总是被路由到同一个专家子集,训练就会变得低效,其他专家也会欠训练。为了缓解这个问题,引入了负载均衡损失,它鼓励将 token 均匀地路由到所有专家。
专家数量和选择 top k 个专家是设计 MoE 模型时的重要因素。更多的专家数量允许扩展到更大的模型,而无需增加计算成本。这意味着模型具有更高的学习能力,然而,超过某个点后,性能提升会趋于减小。选择的专家数量需要与模型推理成本相平衡,因为整个模型都需要加载到内存中。类似地,在训练期间选择较低的 top k 会导致矩阵乘法变小,如果通信成本足够大,这会留下未充分利用的计算资源。然而,在推理期间,较高的 top k 通常会导致推理速度变慢。
MegaBlocks
MegaBlocks 是一个高效的 MoE 实现,它使用稀疏矩阵乘法并行计算专家输出,尽管 token 分配可能不均匀。MegaBlocks 实现了一种无丢弃 (dropless) MoE,它在使用保持高效训练的 GPU kernel 的同时避免丢弃 token。在 MegaBlocks 之前,动态路由设计迫使人们在模型质量和硬件效率之间做出权衡。以前,用户不得不要么从计算中丢弃 token,要么在填充上浪费计算资源和内存。专家可以接收可变数量的 token,并且可以使用块稀疏矩阵乘法高效地执行专家计算。我们将 MegaBlocks 集成到 LLM Foundry 中,以实现 MoE 训练扩展到数千个 GPU。
图 2:专家计算的矩阵乘法
专家并行
随着模型规模变大,无法在单个 GPU 上容纳,我们需要更高级的并行化形式。专家并行是一种模型并行形式,我们将不同的专家放在不同的 GPU 上以获得更好的性能。与在所有 GPU 之间通信专家权重不同,token 被发送到包含专家的设备上。通过移动数据而不是权重,我们可以跨多个机器聚合数据以服务单个专家。路由器决定了输入序列中的哪些 token 应该发送到哪些专家。这通常是通过计算每个 token-专家对的门控分数,然后将每个 token 路由到得分最高的专家来完成的。一旦 token-专家分配确定,就会执行一步 all-to-all 通信,将 token 分派到托管相关专家的设备上。这涉及每个设备发送分配给其他设备上专家的 token,同时接收分配给其本地专家的 token。
专家并行的主要优势在于处理少数更大的矩阵乘法,而不是多个较小的矩阵乘法。由于每个 GPU 只拥有专家子集,它只需要为这些专家进行计算。相应地,随着我们在多个 GPU 之间聚合 token,每个矩阵的大小会按比例增大。由于 GPU 针对大规模并行计算进行了优化,更大的操作可以更好地利用其能力,从而提高利用率和效率。有关更大矩阵乘法优势的更深入解释,请参阅此处。计算完成后,会执行另一步 all-to-all 通信,将专家输出发送回其原始设备。
图 3:专家并行中的 token 路由
我们利用 PyTorch 的 DTensor(一种描述 tensor 如何分片和复制的低级抽象)来有效实现专家并行。我们首先手动将专家放置在不同的 GPU 上,通常跨节点进行分片,以确保在路由 token 时可以利用 NVLink 进行快速 GPU 通信。然后,我们可以在此布局之上构建一个设备网格,这使我们能够简洁地描述整个集群的并行性。当需要其他形式的并行化时,我们可以使用此设备网格轻松地进行检查点或重新排列专家。
使用 PyTorch FSDP 扩展 ZeRO-3
结合专家并行,我们对所有其他层使用数据并行,其中每个 GPU 存储模型和优化器的副本,并处理不同的数据块。在每个 GPU 完成前向和后向传播后,会在所有 GPU 之间累积梯度以进行全局模型更新。
ZeRO-3 是一种数据并行形式,其中权重和优化器在每个 GPU 之间分片,而不是复制。现在每个 GPU 只存储完整模型的子集,从而显著减少了内存压力。当计算需要模型的一部分时,它会在所有 GPU 之间进行收集,计算完成后,收集的权重将被丢弃。我们使用 PyTorch 的 ZeRO-3 实现,称为全分片数据并行 (FSDP)。
随着我们扩展到数千个 GPU,设备间的通信成本增加,减慢了训练速度。通信增加是因为需要在所有 GPU 之间同步和共享模型参数、梯度和优化器状态,这涉及 all-gather 和 reduce-scatter 操作。为了在保持 FSDP 优势的同时缓解此问题,我们利用混合分片数据并行 (HSDP) 将模型和优化器分片到一定数量的 GPU 上,并将此副本多次复制以充分利用集群。使用 HSDP,在反向传播中需要额外的 all reduce 操作来同步副本之间的梯度。这种方法使我们能够在大规模分布式训练期间平衡内存效率和通信成本。为了使用 HSDP,我们可以扩展之前从专家并行获得的设备网格,并让 PyTorch 在需要时完成实际分片和收集的繁重工作。
图 4:FSDP 和 HSDP
借助 PyTorch,我们可以有效结合这两种并行类型,在想要实现专家并行等自定义功能时,可以利用 FSDP 的更高级别 API,同时使用较低级别的 DTensor 抽象。现在我们拥有一个 3D 设备网格,包括专家并行分片维度、ZeRO-3 分片维度以及用于纯数据并行的复制维度。这些技术共同实现了在超大型集群上的近乎线性扩展,使我们能够达到超过 40% 的 MFU 值。
使用 Torch Distributed 实现弹性检查点
容错能力对于确保 LLMs 能够在长时间内可靠地训练至关重要,特别是在节点故障常见的分布式环境中。为了避免作业在不可避免地遇到故障时丢失进度,我们会对模型状态进行检查点,其中包含参数、优化器状态和其他必要的元数据。发生故障时,系统可以从最后保存的状态恢复,而不是从头开始。为了确保对故障的稳健性,我们需要经常进行检查点,并以性能最高的方式保存和加载检查点,以最大程度地减少停机时间。此外,如果太多 GPU 发生故障,集群规模可能会改变。因此,我们需要具备在不同数量的 GPU 上弹性恢复的能力。
PyTorch 通过其分布式训练框架支持弹性检查点,该框架包含用于跨不同集群配置保存和加载检查点的实用程序。PyTorch 分布式检查点确保模型状态可以并行地在训练集群中的所有节点上准确保存和恢复,无论由于节点故障或添加导致的集群组成发生任何变化。
此外,在训练超大型模型时,检查点的大小可能非常大,导致检查点上传和下载时间非常缓慢。PyTorch 分布式检查点支持分片检查点,这使得每个 GPU 只能保存和加载其模型的部分。当将分片检查点与弹性训练结合使用时,每个 GPU 会读取元数据文件,以确定恢复时需要下载哪些分片。元数据文件包含关于每个 tensor 的哪些部分存储在每个分片中的信息。然后,GPU 可以下载属于其模型部分的那些分片,并加载该部分的检查点。
图 5:检查点保存和在额外 GPU 上重新分片后恢复
通过在 GPU 之间并行进行检查点,我们可以分散网络负载,提高稳健性和速度。当使用 3000 多个 GPU 训练模型时,网络带宽很快成为瓶颈。我们利用 HSDP 中的复制功能,首先在一个副本上下载检查点,然后将必要的碎片发送到其他副本。通过与 Composer 的集成,我们可以可靠地每 30 分钟上传一次检查点到云存储,并在节点发生故障时在不到 5 分钟内自动从最新的检查点恢复。
结论
我们很高兴看到 PyTorch 如何以出色的性能支持最先进的 LLMs 训练。在我们的博文中,我们展示了如何在 Foundry 上通过 PyTorch 分布式和 MegaBlocks 实现高效的 MoE 训练。此外,PyTorch 弹性检查点使我们能够在节点发生故障时快速在不同数量的 GPU 上恢复训练。使用 PyTorch HSDP 使我们能够高效地扩展训练,并缩短检查点恢复时间。我们期待继续在一个强大而充满活力的开源社区的基础上构建,帮助将优秀的 AI 模型带给每个人。欢迎加入我们,在 LLM Foundry 和 PyTorch 上构建优秀的模型。