在过去一年中,混合专家 (MoE) 模型越来越受欢迎,这得益于强大的开源模型,如 DBRX、Mixtral、DeepSeek 等等。在 Databricks,我们与 PyTorch 团队紧密合作,以扩展 MoE 模型的训练规模。在这篇博文中,我们将讨论如何使用 PyTorch Distributed 和 MegaBlocks(PyTorch 中高效的开源 MoE 实现)扩展到超过三千个 GPU。
什么是 MoE?
MoE 模型是一种使用多个专家网络进行预测的模型架构。门控网络用于路由和组合专家的输出,确保每个专家都在不同的、专门的 token 分布上进行训练。基于 Transformer 的大型语言模型的架构通常由一个嵌入层和多个 Transformer 块组成(图 1,子图 A)。每个 Transformer 块包含一个注意力块和一个稠密前馈网络(图 1,子图 B)。这些 Transformer 块堆叠在一起,使得一个 Transformer 块的输出成为下一个块的输入。最终输出通过一个全连接层和 Softmax 函数,以获得输出下一个 token 的概率。
当在 LLM 中使用 MoE 时,稠密前馈层被 MoE 层取代,MoE 层由一个门控网络和若干专家组成(图 1,子图 D)。门控网络通常是一个线性前馈网络,它接收每个 token 并生成一组权重,这些权重决定了哪些 token 被路由到哪些专家。专家本身通常也实现为前馈网络。在训练期间,门控网络适应于将输入分配给专家,使模型能够专门化并提高其性能。然后,路由器输出用于加权专家输出,以给出 MoE 层的最终输出。
图 1:在 Transformer 块中使用混合专家模型
与稠密模型相比,对于给定的计算预算,MoE 模型提供更高效的训练。这是因为门控网络仅将 token 发送到一部分专家,从而减少了计算负载。因此,模型的容量(其参数总数)可以增加,而不会成比例地增加计算需求。在推理期间,仅使用部分专家,因此 MoE 能够执行比稠密模型更快的推理。但是,整个模型需要加载到内存中,而不仅仅是正在使用的专家。
MoE 中实现更高计算效率的稀疏性来自于这样一个事实:特定的 token 只会被路由到一部分专家。专家的数量以及如何选择专家取决于门控网络的实现,但一种常见的方法是 top k。门控网络首先预测每个专家的概率值,然后将 token 路由到 top k 个专家以获得输出。但是,如果所有 token 始终都发送到相同的专家子集,则训练效率会降低,而其他专家最终会训练不足。为了缓解这个问题,引入了负载均衡损失,以鼓励均匀路由到所有专家。
专家的数量和选择 top k 个专家是设计 MoE 的重要因素。更多的专家数量允许扩展到更大的模型,而不会增加计算成本。这意味着模型具有更高的学习能力,但是,超过某个点后,性能增益往往会减少。选择的专家数量需要与服务模型的推理成本相平衡,因为整个模型都需要加载到内存中。同样,在选择 top k 时,训练期间较低的 top k 会导致较小的矩阵乘法,如果通信成本足够大,则会浪费计算资源。但是,在推理期间,较高的 top k 通常会导致较慢的推理速度。
MegaBlocks
MegaBlocks 是一种高效的 MoE 实现,它使用稀疏矩阵乘法来并行计算专家输出,尽管 token 分配不均匀。MegaBlocks 实现了无丢弃的 MoE,避免了丢弃 token,同时使用 GPU 内核来保持高效训练。在 MegaBlocks 之前,动态路由公式迫使人们在模型质量和硬件效率之间进行权衡。以前,用户要么必须从计算中丢弃 token,要么在填充上浪费计算和内存。专家可以接收可变数量的 token,并且可以使用块稀疏矩阵乘法有效地执行专家计算。我们已将 MegaBlocks 集成到 LLM Foundry 中,以支持将 MoE 训练扩展到数千个 GPU。
图 2:专家计算的矩阵乘法
专家并行
随着模型扩展到更大的规模并且无法容纳在单个 GPU 上,我们需要更高级的并行形式。专家并行是一种模型并行形式,我们将不同的专家放置在不同的 GPU 上以获得更好的性能。专家权重不会在所有 GPU 之间通信,而是将 token 发送到包含专家的设备。通过移动数据而不是权重,我们可以聚合跨多台机器的数据以用于单个专家。路由器确定输入序列中的哪些 token 应发送给哪些专家。这通常通过计算每个 token-专家对的门控分数来完成,然后将每个 token 路由到得分最高的专家。一旦确定了 token 到专家的分配,就会执行一次 all-to-all 通信步骤,以将 token 分派到托管相关专家的设备。这涉及每个设备发送分配给其他设备上专家的 token,同时接收分配给其本地专家的 token。
专家并行的主要优势是处理少量、较大的矩阵乘法,而不是多个小的矩阵乘法。由于每个 GPU 仅具有一部分专家,因此它只需要为这些专家进行计算。相应地,随着我们聚合跨多个 GPU 的 token,每个矩阵的大小也成比例地增大。由于 GPU 针对大规模并行计算进行了优化,因此较大的操作可以更好地利用其功能,从而提高利用率和效率。有关较大矩阵乘法的优势的更深入解释,请参见 此处。计算完成后,将执行另一次 all-to-all 通信步骤,以将专家输出发送回其原始设备。
图 3:专家并行中的 token 路由
我们利用 PyTorch 的 DTensor(一种用于描述 tensor 如何分片和复制的低级抽象)来有效地实现专家并行。我们首先手动将专家放置在不同的 GPU 上,通常跨节点分片以确保在路由 token 时可以利用 NVLink 进行快速 GPU 通信。然后,我们可以在此布局之上构建一个 设备网格,这使我们能够简洁地描述整个集群中的并行性。当我们需要其他形式的并行性时,我们可以使用此设备网格轻松地检查点或重新排列专家。
使用 PyTorch FSDP 扩展 ZeRO-3
与专家并行结合使用,我们将数据并行用于所有其他层,其中每个 GPU 存储模型和优化器的副本,并处理不同的数据块。在每个 GPU 完成前向和后向传递后,梯度将在 GPU 之间累积以进行全局模型更新。
ZeRO-3 是一种数据并行形式,其中权重和优化器在每个 GPU 之间分片,而不是被复制。现在,每个 GPU 仅存储完整模型的一部分,从而大大降低了内存压力。当计算需要模型的一部分时,它会在所有 GPU 之间收集,并且在计算完成后,将丢弃收集的权重。我们使用 PyTorch 的 ZeRO-3 实现,称为 完全分片数据并行 (FSDP)。
随着我们扩展到数千个 GPU,跨设备通信的成本增加,从而减慢了训练速度。由于需要在所有 GPU 之间同步和共享模型参数、梯度和优化器状态(涉及 all-gather 和 reduce-scatter 操作),通信量会增加。为了在保持 FSDP 优势的同时缓解此问题,我们利用混合分片数据并行 (HSDP) 在一组 GPU 之间分片模型和优化器,并多次复制此副本以充分利用集群。使用 HSDP,在后向传递中需要额外的 all reduce 操作来同步跨副本的梯度。这种方法使我们能够在大规模分布式训练期间平衡内存效率和通信成本。要使用 HSDP,我们可以扩展我们之前的专家并行设备网格,并让 PyTorch 完成在需要时实际分片和收集的繁重工作。
图 4:FSDP 和 HSDP
借助 PyTorch,我们可以有效地结合这两种类型的并行性,利用 FSDP 的更高级别 API,同时在我们想要实现自定义功能(如专家并行)时使用较低级别的 DTensor 抽象。我们现在有一个 3D 设备网格,具有专家并行分片维度、ZeRO-3 分片维度和一个用于纯数据并行的副本维度。这些技术共同实现了在超大型集群上的近线性扩展,使我们能够实现超过 40% 的 MFU 数字。
使用 Torch Distributed 进行弹性检查点
容错对于确保 LLM 可以在较长时间内可靠地训练至关重要,尤其是在节点故障很常见的分布式环境中。为了避免在作业不可避免地遇到故障时丢失进度,我们检查点模型的状态,其中包括参数、优化器状态和其他必要的元数据。当发生故障时,系统可以从上次保存的状态恢复,而不是重新开始。为了确保对故障的鲁棒性,我们需要经常检查点,并以尽可能高性能的方式保存和加载检查点,以最大限度地减少停机时间。此外,如果太多 GPU 发生故障,我们的集群大小可能会更改。因此,我们需要能够在不同数量的 GPU 上弹性恢复。
PyTorch 通过其分布式训练框架支持弹性检查点,其中包括用于跨不同集群配置保存和加载检查点的实用程序。PyTorch 分布式检查点确保模型的状态可以在训练集群中的所有节点上并行地准确保存和恢复,而无需考虑由于节点故障或添加而导致的集群组成的变化。
此外,当训练非常大的模型时,检查点的大小可能会非常大,从而导致非常慢的检查点上传和下载时间。PyTorch 分布式检查点支持分片检查点,这使每个 GPU 仅保存和加载其模型部分。当将分片检查点与弹性训练结合使用时,每个 GPU 读取元数据文件以确定在恢复时要下载哪些分片。元数据文件包含有关每个 tensor 的哪些部分存储在每个分片中的信息。然后,GPU 可以下载其模型部分的切片并加载检查点的该部分。
图 5:检查点保存和恢复在其他 GPU 上重新分片
通过跨 GPU 并行化检查点,我们可以分散网络负载,从而提高鲁棒性和速度。当使用 3000 多个 GPU 训练模型时,网络带宽很快就会成为瓶颈。我们利用 HSDP 中的复制来首先在一个副本上下载检查点,然后将必要的分片发送到其他副本。通过我们在 Composer 中的集成,我们可以可靠地将检查点上传到云存储,频率高达每 30 分钟一次,并在节点发生故障时在不到 5 分钟的时间内自动从最新的检查点恢复。
结论
我们非常高兴看到 PyTorch 如何以出色的性能支持最先进的 LLM 训练。在我们的帖子中,我们展示了如何通过 Pytorch Distributed 和 Foundry 上的 MegaBlocks 实现高效的 MoE 训练。此外,Pytorch 弹性检查点使我们能够在节点发生故障时在不同数量的 GPU 上快速恢复训练。使用 Pytorch HSDP 使我们能够有效地扩展训练规模并缩短检查点恢复时间。我们期待继续在一个强大而充满活力的开源社区的基础上进行构建,以帮助将出色的 AI 模型带给每个人。欢迎加入我们,在 LLM Foundry 和 PyTorch 中构建出色的模型。