合作者:Less Wright, Howard Huang, Chien-Chin Huang,Crusoe 团队:Martin Cala, Ethan Petersen
摘要:我们利用 torchft 和 torchtitan 在真实环境中进行了极限合成故障率下的模型训练,旨在证明容错训练的可靠性和正确性。

在经历 1200 次无检查点故障的情况下的训练损失变化。
注意:每一个小的尖峰都是一个非参与工作节点正在恢复,这会影响指标数据,但不会影响模型本身。
引言
我们希望通过运行极端故障率的训练任务,在最恶劣的情况下展示 torchft 的性能。
大多数 LLM 预训练使用基于 FSDP 的分片模型。torchft 通过 HSDP2 支持分片模型,它结合了分片模型与 torchft 的容错 DDP 全归约(all-reduce)。我们将 torchft 集成到了 torchtitan 中,因此您可以开箱即用。torchft+titan 还支持每个副本组内的其他分片/并行化方式,例如张量并行 (TP)、流水线并行 (PP) 等。
以下是使用 torchft 进行训练任务的结构:

训练任务的结构。torchft 的容错 DDP 实现用于跨副本组同步梯度。每个副本组内使用标准的 FSDP2 和其他并行化技术。

torchft 使用全局 Lighthouse 服务器和每个副本组的管理程序(Managers)进行工作节点的实时协调。Lighthouse 通过心跳机制了解所有工作节点的状态及健康状况。
torchft 实现了几种不同的容错算法。其中最主要的两种是:
- 容错 HSDP:FSDPv2 的扩展,使用容错全归约(all-reduce)。这完全模拟了标准 HSDP 训练,即每一步进行梯度 all_reduce 和每一步容错。它最适合在具有快速后端网络(如 Infiniband)的大规模训练中使用。
- LocalSGD/DiLoCo:半同步训练的容错实现。这些算法通过在指定间隔(而不是像 HSDP 那样每一步)同步来最小化通信开销。这通常用于受通信限制的训练场景,如以太网/TCP 或地理位置分散的环境(联邦学习或多数据中心训练)。
我们一直在关注新算法,例如即将推出的流式 DiLoCo 支持。如果您有新的用例并希望与我们合作,请联系我们!
集群设置
Crusoe 慷慨地为我们提供了 300 块 L40S GPU 的集群。这些 GPU 分布在 30 台主机上,每台主机配备 10 块 NVIDIA L40S GPU。
模型方面,我们使用了 torchtitan 和一个具有 10 亿参数的 Llama 3 模型,以匹配现有的硬件配置。
NVIDIA L40S GPU 通常用于推理,这为我们在非传统环境下测试 torchft 提供了机会,在这些环境中,由于 TCP 网络瓶颈(无 Infiniband/nvlink),DiLoCo 等算法的优势尤为明显。L40S 拥有 48GB 显存(接近消费级 GPU),因此我们使用了较小的模型和批大小。训练平均每步耗时约为 9 秒。
为了在有限的网络条件下最大化性能,我们采用了 30x1x10 的配置。我们设置了 30 个副本组(容错域),每个组包含 1 台主机和 10 个 GPU/工作节点。torchft 的每个副本组可以包含许多主机,但对于该集群,由于网络带宽限制,每副本组单主机/10 GPU 的配置表现最佳。我们运行了 30 个副本组,因为组数越多,对协调和重新配置算法的压力就越大。
在网络通信方面,我们对每个副本组内的通信使用 NCCL(如 FSDP),并对跨副本组的通信使用 Gloo。Gloo 虽然性能往往稍弱,但初始化速度快得多,故障检测速度也更快,这对快速识别故障至关重要。torchft 确实支持在 IB 集群上使用 NCCL 进行容错,但本次演示未使用。由于我们希望最大化总故障和恢复次数,我们选用了 Gloo,因为它在我们的用例中可以在 1 秒内完成重新初始化,并且我们可以将所有操作的超时时间设置为 5 秒。
在容错算法方面,我们主要使用容错 HSDP 进行了大量测试,因为它对通信和法定人数(quorum)层的压力最大。在最终测试中,我们使用了更适合以太网集群的 DiLoCo。
无需检查点的恢复
传统的机器学习通过在错误发生时从检查点重载来实现“容错”。这涉及到完全停止作业(stop-the-world),所有工作节点必须重启并从最近持久化的检查点加载数据。
使用 torchft,我们专注于将故障隔离到单个 GPU 组。当该组内发生错误时,我们可以异步重启该组,而其他组可以重新配置并继续训练,无需该组参与。
当该组通过重启或调度器替换机器恢复时,这些工作节点不再拥有权值和优化器状态的有效副本。如果我们尝试从检查点恢复,其他组早已继续前行。相反,我们依赖运行时的异步权值传输。这通过从健康的副本进行点对点(P2P)传输来实现。
由于我们始终从另一个工作节点进行恢复,事实证明,只要能保证至少有一个组是健康的,我们根本不需要任何检查点。在本次演示中,我们完全关闭了检查点功能,因为持久化保存和加载检查点的时间远长于我们的 P2P 恢复时间。
下图展示了正在恢复的副本(副本 1)如何加入法定人数并从健康的对等节点(副本 0)恢复,且无需停机或影响健康工作节点的训练。

torchft 借鉴了分布式数据库中的许多概念。
- 法定人数(Quorum)操作通过频繁的心跳确定哪些工作节点是健康的,并保证我们可以快速识别存活节点、以容错方式交换元数据,并强制执行无脑裂(split-brain)条件。
- 为了确保一致性并确定何时需要恢复工作节点,我们将训练视为具有传统数据库语义的任务。传统数据库使用“事务”,即每个操作要么提交(完全应用),要么回滚(丢弃)。torchft 对每个训练步骤的处理方式相同。副本组内的每个训练步骤都被视为一个分布式事务,确保所有工作节点在执行优化器步骤时提交该步,如果发生错误,它们都会通过丢弃梯度进行回滚。
有关更多详细信息,请参阅 torchft README,其中包含文档、设计文档和演示文稿的链接。
训练循环集成
TorchFT 已经与 TorchTitan 集成,因此启用它只需设置一个配置标志。对于典型模型,torchft 提供了包装器,会自动调用 TorchFT 管理器的钩子来提供容错功能。
from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo # Instantiate your model and optimizer as normal m = nn.Linear(2, 3) optimizer = optim.AdamW(m.parameters()) # Setup torchft Manager and wrap the model and optimizer. manager = Manager( pg=ProcessGroupGloo(), load_state_dict=lambda state_dict: m.load_state_dict(state_dict), state_dict=lambda: m.state_dict(), ) m = DistributedDataParallel(manager, m) optimizer = Optimizer(manager, optimizer) for batch in dataloader: # When you call zero_grad, we start the asynchronous quorum operation # and perform the async weights recovery if necessary. optimizer.zero_grad() out = m(batch) loss = out.sum() # The gradient allreduces will be done via torchft's fault tolerant # ProcessGroupGloo wrapper. loss.backward() # The optimizer will conditionally step depending on if any errors occured. # The batch will be discarded if the gradient sync was interrupted. optimizer.step()
容错调度
由于副本组内工作节点的语义与普通作业相同,我们可以使用 Slurm 等标准 ML 作业调度器。如果组内任何工作节点发生错误,我们期望整个组同时重启。在每个副本组内,应用程序是一个使用标准非容错操作的完全标准的训练作业。
为了在传统调度器上实现容错,我们运行多个这样的作业。每个副本组在 Slurm 上作为一个单独的训练作业运行,Lighthouse 和监控脚本在主节点上运行。所有跨组通信均通过 torchft 的托管 ProcessGroup 和法定人数 API 进行。我们使用一个小型的 torchx Python API 脚本来在故障时重启组并注入故障。
监控脚本如下所示:
from torchx.runner import get_runner NUM_REPLICA_GROUPS = 30 with get_runner() as runner: while True: jobs = runner.list(scheduler) active_replicas = { parse_replica_id(job.name) for job in jobs if not job.is_terminal() } missing_replicas = set(range(NUM_REPLICA_GROUPS)) - active_replicas for replica_id in missing_replicas: app_def = make_app_def(replica_id=replica_id) app_handle = runner.run( app_def, scheduler="slurm", cfg={"partition": "batch"}, ) print("launched:", replica_id, app_handle) time.sleep(5.0)
故障是通过使用 `scancel` 取消特定副本组的 Slurm 作业来注入的。在现实世界中,我们期望故障是由训练过程中的错误触发的,这将导致该副本组在隔离状态下崩溃,而不是发生外部故障。
指标与日志
为了确保对作业有统一的视角,我们避免在一个副本组内注入故障,以便更轻松地追踪作业的指标和法定人数事件。该组能够持续记录参与者数量、步骤成功/失败情况以及损失值。
由于我们执行的是每步容错,参与者数量(即批大小)会根据哪些工作节点健康而每步变化。
损失值通过跨副本组的 all-reduce 对所有工作节点/副本组进行平均。
注意:损失图中出现的小尖峰是由于我们如何计算所有主机(包括正在恢复、权值过期的工作节点)的平均损失,这会导致这些步骤的损失值出现错误的偏高。
运行情况
我们进行了三次不同的运行,展示了各种故障场景和 torchft 的特性。
运行 1:每 60 秒注入一次故障,共 1100 次故障

这次运行持续了超过 19 小时,共 6249 个步骤。平均每个步骤耗时 10.9 秒。
在初始运行中,我们以非常可重复的模式每 60 秒注入一次故障。最初集群中有一台坏机器,因此我们将总世界大小暂时缩减为 25 台主机,直到机器被更换,之后我们又以零停机时间将作业规模扩展回去。
由于每 60 秒发生一次故障,我们预计在每次故障之间可以完成约 5 个步骤而不会有任何问题。查看结果,共有 6249 个步骤,其中 5145 次成功提交。torchft 的设计旨在尽可能安全,如果发生任何错误,它会在运行优化器之前通过“should_commit”丢弃该步骤。
关于总体步骤效率,我们有:
5145 次成功步骤 / 6249 次总步骤 = 82.3%
步骤耗时约 11 秒,每 60 秒发生一次故障,我们应该能够完成每 6 个步骤中的 5 个(83.3%),这与测量性能几乎完全匹配。
我们平均每步有 29.6 个参与副本组,因此总训练效率为 81.2%。对于超过 1000 次故障来说,这已经很不错了。
运行 2:每 15 秒注入一次故障,共 1015 次故障
我们想看看能将这项技术推向何种程度,并增加难度。在第二次运行中,我们随机在 0-30 秒之间注入故障,平均每 15 秒发生一次。
这种故障率与训练作业相比极其极端(通常平均故障间隔时间在数十分钟到数小时之间),但这让我们能够验证无论错误何时发生都能恢复,并允许我们运行大量测试周期以增强对实现的信心。
通过随机化故障间隔,我们使得故障在工作节点仍在初始化时发生,而不是在稳态时发生,这更容易触及边缘情况。我们很高兴地报告,torchft 的表现符合预期,没有出现不可恢复的错误。

如您所见,这项作业的表现非常不稳定。与 60 秒故障率时接近 30 台机器的情况不同,在每 15 秒一次故障的情况下,每一步可用的机器数量从 1 台到 30 台不等。
平均而言,我们在任何给定步骤中有 18.9 个(18.9/30 = 63%)健康的参与工作节点,平均步耗时为 15.46 秒。
在前 888 个步骤中,有 268 个步骤成功提交,步骤效率为 30.2%。
这使得我们的训练效率仅为 13.4%,在任何正常的训练作业中这都很糟糕,但令人瞩目的是,尽管每 15 秒崩溃一次,模型仍在收敛! 仅从检查点加载模型通常就需要超过 1 分钟的时间。
损失收敛速度比我们的 60 秒平均故障间隔时间(MTBF)运行慢,这是预料之中的,因为有更多的批次因错误而被丢弃。
我们确实看到损失值出现了一些较大的尖峰,这些尖峰与只有 1 个参与者处于健康状态的时间点相关(因此批大小仅为 1/30)。通过调整最少副本数可以轻松避免这种情况。我们在测试中将其设置为 1。
运行 3:半同步训练
TorchFT 还支持半同步训练算法,包括 LocalSGD 和 DiLoCo,并计划在未来增加更多。与 HSDP2 不同,这些算法不会在每一步进行同步。相反,它们在同步权值(通过平均参数或梯度)之前进行若干步的本地训练。这种方法通过将通信成本降低到每 N 步一次(可配置的超参数)而不是每一步一次,从而提高了性能。我们在集群上的测试表明吞吐量有了显著提升。当每 40 步同步一次时,我们最小化了通信开销,从而获得了更高的整体吞吐量。下图是 DiLoCo 吞吐量(黄色,平均约 4000 tps)与常规 HSDP2(紫色,平均约 1200 tps)的比较。

自然地,同步间隔越长,副本组内的模型就会产生越大的分歧。这种分歧可能会影响模型的收敛。然而,在测试中,我们观察到尽管同步间隔较长,模型仍然能够有效地进行训练并达到收敛。这种韧性在副本可能意外离开组的动态环境中非常有用。即使在这种情况下,模型也展示了无需重大中断即可继续训练的能力。

下一步
torchft 正处于积极开发阶段,我们计划在流式 DiLoCo 等新算法方面进行大量改进,使 PyTorch Distributed 对故障更具鲁棒性(即使在 Infiniband/nvlink 上也是如此!),并实现更高的效率。
如果您有兴趣使用 torchft,请查阅 torchft README 和 torchft 文档。我们也期待与您交流,请随时通过 GitHub、LinkedIn 或 Slack 直接联系我们。