合作者:Less Wright, Howard Huang, Chien-Chin Huang, Crusoe: Martin Cala, Ethan Petersen
总结:我们使用torchft和torchtitan在一个真实的、具有极端合成故障率的环境中训练模型,以证明容错训练的可靠性和正确性。
在1200次故障中进行训练损失,无检查点。
注意:每个小尖峰都是非参与工作者恢复,这会影响指标但不会影响模型。
引言
我们希望通过以尽可能极端的故障率运行训练作业,在最坏情况下演示torchft。
大多数LLM预训练使用FSDP分片模型。torchft支持使用HSDP2的分片模型,它将分片模型与torchft的容错DDP all reduce结合起来。我们已将torchft集成到torchtitan中,因此您可以开箱即用地使用容错功能。torchft+titan还支持每个副本组内的其他分片/并行性,如张量并行性(TP)、流水线并行性(PP)等。
以下是使用torchft的训练作业结构
训练作业的结构。torchft的容错DDP实现在副本组之间用于同步梯度。标准FSDP2和其他并行化在每个副本组内使用。
torchft使用一个全局Lighthouse服务器和每个副本组的管理器来实时协调工作者。Lighthouse通过心跳了解所有工作者的状态以及哪些是健康的。
torchft实现了几种不同的容错算法。其中两个最主要的是:
- 容错HSDP:FSDPv2的扩展,使用容错all-reduce。这完全模拟了标准HSDP训练,每步进行梯度all_reduce和每步容错。这最适合具有快速后端网络(如infiniband)的大规模训练。
- LocalSGD/DiLoCo:半同步训练的容错实现。这些算法通过在指定间隔而不是像HSDP那样每一步进行同步来最小化通信开销。这通常用于通信受限的训练场景,例如通过以太网/TCP或地理位置分离的地点(联邦学习或多数据中心训练)。
我们一直在关注新的算法,例如即将支持的流式DiLoCo。如果您有新的用例想合作,请联系我们!
集群设置
Crusoe慷慨地借给我们一个包含300个L40S GPU的集群。这些GPU分布在30台主机上,每台主机有10个NVIDIA L40S GPU。
对于模型,我们使用了torchtitan和一个具有10亿参数的Llama 3模型,以匹配可用的硬件。
NVIDIA L40S GPU通常用于推理,因此为我们提供了一个机会,在非传统环境中测试torchft,其中DiLoCo等功能由于较低的仅TCP(无infiniband/nvlink)网络瓶颈而真正发挥作用。L40S具有48GB的VRAM(更接近消费级GPU),因此我们使用了更小的模型和批量大小。训练的平均步长约为9秒。
为了最大限度地提高有限网络下的性能,我们以30x1x10的配置训练模型。我们有30个副本组(容错域),每个组有1台主机和10个GPU/工作节点。torchft可以在每个副本组中拥有许多主机,但对于此集群,由于网络带宽有限,每个副本组一个主机/10个GPU的性能最佳。我们运行了30个副本组,因为更多的组对协调和重新配置算法的压力更大。
对于网络通信,我们在每个副本组内部的所有通信(即FSDP)中使用NCCL,在副本组之间使用Gloo。Gloo虽然性能通常不如NCCL,但初始化速度快得多,并且失败检测也快得多,这对于快速检测故障很重要。torchft确实支持使用NCCL在IB集群上实现容错,但有一些注意事项,在此演示中未使用。由于我们希望最大化故障和恢复的总数,我们使用了Gloo,因为它在我们的用例中可以在<1秒内重新初始化,并且我们能够将所有操作的超时设置为5秒。
对于容错算法,我们主要使用容错HSDP进行测试,因为它对通信和仲裁层施加的压力最大。对于最终测试,我们使用了DiLoCo,它更适合基于以太网的集群。
无需检查点即可恢复
传统的机器学习通过在发生错误时从检查点重新加载来实现“容错”。这涉及到完全的“停止世界”操作,所有工作节点都会重新启动并从最近持久化的检查点加载。
使用torchft,我们转而专注于将故障隔离到单个GPU组。当该组内发生错误时,我们可以异步重启该组,而所有其他组可以重新配置并继续训练而无需等待该组。
当该组通过重启或调度器替换机器而恢复时,这些工作节点不再拥有有效的权重和优化器状态副本。如果尝试从检查点恢复,其他组已经继续进行了。相反,我们依赖于运行时异步权重传输。这通过点对点传输从健康的副本获取权重。
由于我们总是从另一个工作节点恢复——事实证明,只要我们能保证至少有一个组是健康的,我们实际上不需要任何检查点。对于这个演示,我们完全关闭了检查点功能,因为持久检查点的保存和加载比我们的P2P恢复时间长得多。
这张图显示了一个正在恢复的副本(副本1)如何加入仲裁并从一个健康的对等副本(副本0)恢复,而不会造成任何停机时间或影响健康工作节点的训练。
torchft借鉴了分布式数据库的许多概念
- 仲裁操作通过频繁的心跳确定哪些工作节点是健康的,并确保我们能够快速确定哪些工作节点在线,以容错的方式交换元数据,并强制执行无脑裂条件。
- 为了确保一致性并确定何时需要恢复工作节点,我们实际上将训练视为具有传统数据库语义。传统数据库使用“事务”,其中每个操作要么提交(完全应用),要么回滚(丢弃)。torchft以同样的方式处理每个训练步骤。副本组中的每个训练步骤都作为分布式事务处理,我们确保所有工作节点通过步进优化器来提交步骤,或者如果发生错误,它们都通过丢弃梯度来回滚。
更多详情,请参阅torchft README,其中包含文档、设计文档和演示的链接。
训练循环集成
TorchFT 已经与 TorchTitan 集成,因此,启用它只需设置一个配置标志。对于典型的模型,torchft 提供了封装器,这些封装器会自动调用 torchft 管理器中的钩子,以提供容错功能。
from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo # Instantiate your model and optimizer as normal m = nn.Linear(2, 3) optimizer = optim.AdamW(m.parameters()) # Setup torchft Manager and wrap the model and optimizer. manager = Manager( pg=ProcessGroupGloo(), load_state_dict=lambda state_dict: m.load_state_dict(state_dict), state_dict=lambda: m.state_dict(), ) m = DistributedDataParallel(manager, m) optimizer = Optimizer(manager, optimizer) for batch in dataloader: # When you call zero_grad, we start the asynchronous quorum operation # and perform the async weights recovery if necessary. optimizer.zero_grad() out = m(batch) loss = out.sum() # The gradient allreduces will be done via torchft's fault tolerant # ProcessGroupGloo wrapper. loss.backward() # The optimizer will conditionally step depending on if any errors occured. # The batch will be discarded if the gradient sync was interrupted. optimizer.step()
容错调度
我们可以使用标准的ML作业调度器,如Slurm,因为副本组内工作节点的语义与普通作业相同。如果组内任何工作节点发生错误,我们期望整个组同时重启。在每个副本组内,应用程序是使用标准非容错操作的完全标准训练作业。
为了在传统调度器上实现容错,我们运行了多个这样的作业。每个副本组都在Slurm上作为一个独立的训练作业运行,Lighthouse和监控脚本在头节点上运行。所有跨组通信都通过torchft管理的ProcessGroup和仲裁API完成。为了在故障时重启组和注入故障,我们使用了一个使用torchx Python API的小脚本。
监控脚本大致如下:
from torchx.runner import get_runner NUM_REPLICA_GROUPS = 30 with get_runner() as runner: while True: jobs = runner.list(scheduler) active_replicas = { parse_replica_id(job.name) for job in jobs if not job.is_terminal() } missing_replicas = set(range(NUM_REPLICA_GROUPS)) - active_replicas for replica_id in missing_replicas: app_def = make_app_def(replica_id=replica_id) app_handle = runner.run( app_def, scheduler="slurm", cfg={"partition": "batch"}, ) print("launched:", replica_id, app_handle) time.sleep(5.0)
通过使用scancel取消特定副本组的Slurm作业来注入故障。在真实世界场景中,我们期望故障是由训练过程中的错误触发的,该错误将导致该副本组独立崩溃,而不是外部故障。
指标和日志
为了确保我们对作业有一致的视图,我们避免向一个副本组注入故障,以便更简单地跟踪作业的指标和仲裁事件。该组能够持续记录参与者数量、步骤成功/失败和损失。
由于我们正在执行每步容错,参与者数量以及批量大小会根据哪些工作节点健康而每步变化。
损失通过跨副本组的allreduce在作业中的所有工作节点/副本组之间平均。
注意:下图损失曲线中的小尖峰是由于我们平均所有主机(包括恢复中的工作节点)的损失的方式造成的,这些工作节点的权重已过时,导致这些步骤的损失错误地更高。
运行
我们运行了三个不同的测试,展示了torchft的各种故障场景和功能。
运行1:每60秒注入一次故障,共1100次故障
这次运行持续了19个多小时,共6249步。平均而言,每一步耗时10.9秒。
对于首次运行,我们以非常可重复的模式每60秒注入一次故障。最初集群中有一台坏机器,因此我们暂时将世界规模缩小到25台主机,直到机器被替换,然后我们在零停机的情况下将作业扩展回来。
以每60秒一次的故障率,我们预计能够在每次故障之间无问题地完成约5步。查看结果,我们发现有6249步和5145次成功提交。torchft被设计为尽可能安全,如果发生任何错误,它将在运行优化器之前通过“should_commit”丢弃该步骤。
对于整体步骤效率,我们有
5145个成功步骤 / 6249个总步骤 = 82.3%
以约11秒的步长时间和每60秒一次的故障率,我们应该能够完成每6步中的5步(83.3%),这与测量性能几乎完全匹配。
我们平均每步有29.6个参与的副本组,因此这次训练的总效率为81.2%。对于超过1000次故障来说,这还不错。
运行2:每15秒注入一次故障,共1015次故障
我们想看看能将此推到多远,并使其更具挑战性。在第二次运行中,我们在0-30秒之间注入故障,平均每15秒一次。
与训练作业通常每10分钟到几小时的平均故障间隔时间相比,这种故障率是极端的,但这让我们能够验证无论何时发生错误都能恢复,并让我们进行大量的测试循环,以增强对我们实现的信心。
通过随机化故障间隔,我们使得故障发生在工作节点仍在初始化而不是处于稳定状态时,并且更有可能触及边缘情况。我们很高兴地报告,torchft表现如预期,没有出现不可恢复的错误。
如您所见,这个作业表现得更加不稳定。与60秒故障率下非常接近30台机器相比,在每15秒一次故障的情况下,每一步可用的机器数量从1台到30台不等。
平均而言,我们在任何给定步骤上都有18.9个(18.9/30 = 63%)工作节点健康并参与,平均步长为15.46秒。
在前888个步骤中,有268个步骤成功提交,这使我们的步长效率达到30.2%。
这使得我们的训练效率仅为13.4%,这在任何正常训练作业中都是糟糕的,但令人惊讶的是,即使每15秒发生一次崩溃,模型仍在收敛!从检查点加载模型通常需要超过1分钟的时间。
损失收敛速度比我们60秒MTBF的运行慢,但这是预期的,因为有更多的批次由于错误而被丢弃。
我们确实看到损失出现一些更大的尖峰,这与只有一个参与者健康并因此只有1/30批量大小的时候相关。这可以通过调整最小副本数轻松避免。我们在这次测试中将其设置为1。
运行3:半同步训练
TorchFT还支持半同步训练算法,包括LocalSGD和DiLoCo,并计划在未来添加更多。与HSDP2不同,这些算法不是每一步都同步。相反,它们会进行几步的局部训练,然后通过平均参数或梯度来同步权重。这种方法通过将通信成本降低到每N步一次(一个可配置的超参数)而不是每一步,从而提高了性能。我们在集群上的测试表明吞吐量有显著提高。当每40步同步一次时,我们最大限度地减少了通信开销,从而实现了更高的整体吞吐量。下面是DiLoCo(黄色)和常规HSDP2(紫色)的吞吐量比较,DiLoCo平均约为4000 tps,而HSDP2平均约为1200 tps。
当然,同步间隔越长,副本组内模型的分歧就越大。这种分歧可能会影响模型的收敛。然而,在我们的测试中,我们观察到尽管同步间隔较长,模型仍然能够有效地训练并达到收敛。这种韧性在副本可能意外离开组的动态环境中非常有用。即使在这种情况下,模型也表现出能够继续训练而没有显著中断的能力。
后续步骤
torchft 正在积极开发中,我们计划在新的算法上进行大量改进,例如流式 DiLoCo,使 PyTorch 分布式对故障(甚至在 Infiniband/NVLink 上!)更具鲁棒性,并提高效率。
如果您对使用 torchft 感兴趣,请查看torchft README和torchft 文档。我们也乐意与您交流,请随时通过 GitHub、LinkedIn 或 Slack 直接联系我们。