博客

TorchSpec:大规模投机解码训练

引言

在过去的一年里,大语言模型在规模和能力上都经历了快速扩展。Kimi K2.5、GLM 5 和 Qwen 3.5 等前沿模型现在拥有数千亿参数和长达数百万 token 的上下文窗口,支持长文本推理、代理工作流和复杂的工具使用。随着这些模型能力不断增强,高效推理已成为大模型部署中最重要的系统挑战之一。

推测解码(Speculative Decoding)是加速大语言模型生成最有效的技术之一。在推测解码中,轻量级的草稿模型(Draft Model)先行预测多个 token,随后由较大的目标模型(Target Model)在一次前向传播中进行验证。当预测被采纳时,可以一次性生成多个 token,从而显著提高吞吐量并降低延迟。近期出现的如 MTP(多 Token 预测)和 EAGLE-3 等方法证明,经过良好训练的草稿模型能够带来持续的性能提升。

草稿模型训练的一个关键环节是通过中间隐藏状态(Hidden States)将信息从目标模型转移到草稿模型。随着前沿大模型规模不断扩大,引入了新的系统瓶颈:如何高效地将目标模型生成的海量隐藏状态传递给草稿模型。例如,EAGLE-3 依赖目标模型 3 层的隐藏状态。在训练 Kimi K2.5 的 EAGLE-3 草稿模型时,单个 128K token 的训练样本就需要约 7 GB 的隐藏状态。在数据集规模下,这会变得极其庞大。

现有的方案通常采用以下两种方式之一。一种是预先计算并存储隐藏状态到磁盘,这会导致极高的存储需求和严重的 I/O 压力。另一种是将推理与训练并置(Co-location),在草稿模型训练过程中实时生成隐藏状态,虽然避免了磁盘存储,但要求目标模型与训练工作节点并置,这会产生巨大的 GPU 显存压力。

为了解决这些挑战,我们推出了 TorchSpec,这是一个用于解耦推测解码训练的 PyTorch 原生框架。TorchSpec 将生成隐藏状态的推理系统与消耗这些状态的训练系统分离开来。隐藏状态无需写入磁盘,而是通过中央 Mooncake 存储,利用 RDMA(远程直接内存访问)或 TCP 直接从推理引擎组流向训练工作组。这种设计消除了磁盘存储需求,并允许推理和训练资源独立扩展。

利用 TorchSpec,我们成功训练了一个 Kimi K2.5 EAGLE-3 草稿模型,耗时 1500 个 H200 GPU 小时,涵盖 60 万条训练样本,共计 60 亿个 token。该草稿模型在多项基准测试中表现出色:

*草稿模型训练时 lookahead=4

在使用 3 个 token 的 lookahead 设置下,草稿模型使输出吞吐量在 batch size 为 1 时提升超过 60%,batch size 为 8 时提升 30%,batch size 为 16 时提升 26%。

背景

目前主流的推测解码训练方法有两种:

  • 推理训练并置方案
  • 离线隐藏状态准备方案

两者在中等规模下有效,但在草稿模型规模和上下文长度增加时会面临严峻挑战。

推理训练并置方案

在并置训练中,目标模型和草稿模型共享相同的 GPU。目标模型运行前向传播以产生隐藏状态和 Logits,随后草稿模型立即消耗这些数据进行训练。由于目标模型和草稿模型紧密耦合,该方案存在多项局限性:

  • 分片刚性:草稿模型的并行策略必须与目标模型绑定。例如,如果目标模型使用 TP=4,则草稿模型也必须使用 4 个 Rank,即使该配置对其较小的架构而言并非最优。
  • 无法独立扩展:目前的并置框架通常缺乏跨节点分片支持,将训练限制在单节点的 GPU 内。更重要的是,推理和训练资源被绑定,无法单独配置。
  • 内存压力:目标模型占用了巨大的 GPU 显存,导致用于训练草稿模型的显存非常有限。

并置训练的内存分析(以 Kimi K2.5 为例:1 万亿参数 MoE 模型,384 个专家,模型权重约 575 GB):

GPU 总显存(8 GPU) 模型权重 每 GPU 分片 剩余显存
8×H200 1,128 GB ~575 GB ~72 GB ~69 GB
8×H100 640 GB ~575 GB ~72 GB ~8 GB

虽然草稿模型通常很小,但诸如 Training-Time Testing (TTT) 等先进训练方法需要很高的内存,因为它保留了多个推测步骤的中间激活值。激活值的积累推高了整体显存开销。在 8 GB 剩余显存的情况下,我们只能进行 4096 上下文长度的训练。

离线隐藏状态准备方案

离线方案预先计算目标模型的隐藏状态,序列化并存储到磁盘,然后在训练时加载。这虽然实现了推理与训练的解耦,但引入了显著的存储挑战,特别是对于长上下文的大模型而言。

Kimi K2.5 存储分析(hidden_size=7168, vocab_size=163,840)

单样本(上下文长度 = 131,072 token)

张量 形状 数据类型 大小
隐藏状态 (3 个辅助层) (131072, 21504) bf16 5.25 GB
最后隐藏状态 (131072, 7168) bf16 1.75 GB
输入 ID (131072,) int64 1 MB
每条样本总计 ~7.0 GB

注意:目标 Logits 可以通过 lm_head 从最后隐藏状态重新计算,因此无需存储。即便如此,存储需求也会迅速膨胀。

数据集大小 存储需求
10,000 样本 70 TB
30,000 样本 210 TB
100,000 样本 700 TB

在此规模下,分布式文件系统面临巨大压力,特别是在多个推测训练任务并发运行时,它们会争抢 I/O 带宽。序列化和反序列化的开销也极大地拖慢了训练速度。

TorchSpec:解耦草稿模型训练

TorchSpec 采用了不同的策略:完全解耦的推理与训练。目标模型在专用的推理 GPU 上运行,草稿模型在独立的训练 GPU 上进行训练,张量数据通过 Mooncake 存储利用高速网络协议 RDMA 或 TCP 在两者之间传输。

该架构解决了上述核心挑战:

  1. 灵活且独立的扩展。推理和训练 GPU 数量完全独立,允许配置更多的推理引擎以获得更高的隐藏状态生成吞吐量,或添加更多的训练 GPU 以支持更大的 FSDP 分片和全局 batch size。
  2. 完整的训练显存。训练 GPU 完全专用于草稿模型,最大化了处理长序列和大批量数据所需的可用显存。
  3. 无存储开销。隐藏状态通过 RDMA/TCP 从推理侧直接流向训练侧。数据无需下沉到磁盘,消除了文件系统压力和序列化成本。

为什么选择 Mooncake?

Mooncake 最初由月之暗面(Moonshot AI)和清华大学开发,是一款专为生产环境 LLM 服务中的 KV Cache 管理而设计的传输引擎。它现已演进为 PyTorch 生态中蓬勃发展的社区项目。Mooncake 能够通过多种网络协议处理高吞吐量的跨节点数据传输,并管理内存生命周期。这些正是 TorchSpec 高效、可靠地将隐藏状态从推理 GPU 传输到训练 GPU 所需的核心功能。

使 Mooncake 成为不二之选的关键特性:

  • 统一 API 支持 RDMA + TCP。在 InfiniBand/RoCE 集群上实现近线速传输;在无 RDMA 时可直接通过 TCP 工作,无需修改代码。
  • GPU Direct RDMA。直接将数据传输到 GPU 显存,绕过 CPU 中转——这对每个训练样本包含数 GB 隐藏状态的场景至关重要。
  • 零拷贝传输。张量直接打包至预注册的 pinned-memory 缓冲区中传输,无需序列化或中间拷贝。
  • 生产级可靠性。经过大规模生产部署的实战考验,为 TorchSpec 长期运行的多节点训练提供了坚实基础。

长上下文支持

由于显存完全专用于草稿模型,TorchSpec 支持在 EAGLE-3 训练中通过并置方案无法实现的长序列长度。例如,Kimi K2.5 在并置训练中消耗 72 GB 显存。在 lookahead 为 4 的解耦训练方案下,单个 H100 GPU 可支持高达 44K token 的序列训练,单个 B200 GPU 可扩展至 200K token。

除了解耦,TorchSpec 还采用了原生推理引擎实现:隐藏状态直接由生产环境的推理引擎生成。这一设计选择带来了两个关键优势:

  • 推理-训练对齐:模板格式化、分词和内核完全对齐。训练环境与部署环境之间不存在偏差。
  • 通过引擎支持原生模型:在训练侧支持新目标模型架构只需极小的改动。目前 TorchSpec 已支持 vLLM 和 SGLang,TensorRT LLM 的支持也即将推出。只要推理引擎支持该模型,TorchSpec 就可以直接为其训练草稿模型。这包括:
    • 新模型架构(MoE、多模态等)
    • 量化模型(FP8、INT4 等)
    • 稀疏注意力机制、RoPE(旋转位置编码)变体以及其他模型特有功能

Train with Decode(生成式训练)

当草稿模型在目标模型的 Token 分布上进行训练时,性能通常最佳。一种常见做法是保留数据集的原始提示词,并由目标模型重新生成回复作为训练前的预处理。然而,这种双阶段流程对研究人员和工程师来说非常繁琐。通过我们的引擎原生设计,我们可以在训练过程中从纯提示词输入自动回归地生成输出。

案例研究:为 Kimi K2.5 训练 EAGLE-3 模型

Kimi K2.5 提供了一个极具挑战性的训练场景,展现了解耦方案的价值。

挑战

  • 模型规模:Kimi K2.5 仅部署目标模型就需要至少 8×H200 或 16×H100 GPU,如果并置训练,留给草稿模型的显存极度匮乏。
  • 长上下文:Kimi K2.5 针对长上下文代理和推理工作负载,需要支持多达 200,000 个 token 的序列训练。
  • 大词汇表:拥有 163,840 个 token 的词汇表和 7,168 的隐藏维度。

TorchSpec 解决方案

利用 TorchSpec,我们建议使用 8×H200 GPU 部署 Kimi K2.5 作为专用推理引擎,并使用另外 8×H200 GPU 训练 EAGLE-3 草稿模型。推理集群拥有完整的显存用于服务和生成隐藏状态;训练集群则拥有完整的 GPU 显存用于草稿模型,使得在 60 万条数据样本下进行 100,000 token 的长上下文训练成为可能。

脚本:我们提供了两个开箱即用的脚本用于训练 Kimi K2.5 草稿模型:

– 3 节点 8xH100 (TP=16 推理,TP=8 训练):kimi-k25-3node-h100

– 2 节点 8xH200 (TP=8 推理,TP=8 训练):kimi-k25-2node-h200

训练数据集:我们开源了混合数据集:kimi-600k-training-dataset

草稿模型:我们开源了草稿模型:kimi-k2.5-eagle3

路线图

TorchSpec 正处于活跃开发阶段,我们重点关注以下领域:

  • 提升模型覆盖范围:计划支持更多热门模型,如 Minimax M2.5、Qwen 3.5,并支持 GLM 5 的 MTP 层持续训练。
  • 打包序列训练(Packed sequence training):将多个较短序列打包成单个训练样本,以最大化 GPU 利用率并减少填充浪费,特别是针对变长输入的数据集。
  • 扩展训练算法:超越 EAGLE-3,支持如 DFlash、MTP 等其他推测解码训练方法,扩充 TorchSpec 可训练的草稿模型架构范围。
  • 引擎集成:与更多流行推理引擎(如 TensorRT LLM)集成,以便用户能够将其与最适合的部署栈相连。

致谢

我们衷心感谢以下团队和合作者:

TorchSpec 团队及社区:*Yubo Wang, *Yinghui Liu, Shirley Wu, Junxiong Wang, Qingyang Wu, Bobbie Bie, Fan Yin, Chao Wang, Weicong Wu, Jue Wang

Mooncake 团队:*Jiaqi Liao, Mingxing Zhang