作者:IBM PyTorch 团队

推测解码是一种推理优化技术,它在生成当前 token 的同时,在一个前向传递中对未来的 token 进行有根据的猜测。它包含一个验证机制,以确保这些推测 token 的正确性,从而保证推测解码的整体输出与普通解码的输出相同。优化大型语言模型 (LLM) 的推理成本,可以说是降低生成式 AI 成本和提高其普及度的最关键因素之一。为了实现这一目标,有多种推理优化技术可用,包括自定义内核、输入请求的动态批处理以及大型模型的量化。

在这篇博文中,我们将提供推测解码指南,并演示它如何与其他优化技术共存。我们很荣幸能开源以下内容,其中包括首个针对 Llama3 模型的推测器:

  1. 适用于 Meta Llama3 8BIBM Granite 7B labMeta Llama2 13BMeta Code Llama2 13B 的推测器模型。
  2. 通过 IBM 基于 Hugging Face TGI 的分支进行推理的代码。
  3. 用于训练您自己的推测器及相应秘籍的代码。

我们已将这些推测器部署到内部生产级环境中,日活跃用户达数千人,并观察到在语言模型(Llama3 8B、Llama2 13B 和 IBM Granite 7B)上实现了 2 倍加速,在 IBM 的 Granite 20B 代码模型上实现了 3 倍加速。我们在此 技术报告 中提供了方法的详细解释,并计划在即将发表的 ArXiv 论文中进行深入分析。

推测解码:推理

我们在内部生产环境中运行 IBM TGIS,该环境包含连续批处理、融合内核和量化内核等优化措施。为了在 TGIS 中启用推测解码,我们修改了来自 vLLM 的分页注意力内核。接下来,我们将描述为启用推测解码而对推理引擎进行的关键修改。

推测解码基于模型足够强大,能够在单个前向传递中预测多个 token 的前提。然而,当前的推理服务器优化后每次只能预测一个 token。在我们的方法中,我们在 LLM 上附加了多个推测头(speculative heads)(除了通常的一个),以预测第 N+1、N+2、N+3 等 token。例如,3 个头将预测 3 个额外的 token。推测器架构的详细信息将在本博文的后面部分解释。在推理过程中实现效率和正确性面临两个挑战——一个是在不复制 KV 缓存的情况下进行预测,另一个是验证预测与原始模型的输出是否匹配。

在典型的生成循环中,在单个前向步骤处理完提示词后,一个序列长度为 1(预测的下一个 token)与 KV 缓存一起被送入模型的前向传递。在朴素的推测解码实现中,每个推测头会有自己的 KV 缓存,但我们instead,修改了 vLLM 项目中开发的分页注意力内核,以实现高效的 KV 缓存维护。这确保了在更大的批处理大小时吞吐量不会降低。此外,我们修改了注意力掩码,以实现对第 N+1 个 token 的验证,从而在不偏离原始模型输出的情况下启用推测解码。此实现的详细信息在此处 提供

结果

我们使用一个简单的提示词,展示了使用 Meta 的 Llama2 13B 对话版本获得的加速效果。

Visual illustration of the non-speculative generation (left) compared to speculative generation (right)

图 2:非推测生成(左)与推测生成(右)的可视化对比

我们将上述解决方案部署到内部生产环境中。下图报告了两个指标——首 token 时间 (TTFT) 和 token 间隔延迟 (ITL),图中曲线上的数字表示不同的并发用户数。我们观察到,对于所有批处理大小,推测解码版本在 Llama2 13B 对话模型上几乎是双倍加速,在 Granite 20B 代码模型上几乎是三倍加速,而与非推测版本相比。我们对较小模型——IBM 的 Granite 7B 和 Meta Llama3 8B 模型——也观察到了类似的行为。

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Llama 13B with number of concurrent users indicated on the graph

图 3:Llama 13B 的首 token 时间(TTFT - 左)和 token 间隔延迟(ITL - 右),图中标注了并发用户数

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Granite 20B Code with number of concurrent users indicated on the graph

图 4:Granite 20B Code 的首 token 时间(TTFT - 左)和 token 间隔延迟(ITL - 右),图中标注了并发用户数

关于效率的说明

我们进行了大量实验来确定推测器训练的正确配置。这些包括:

  1. 推测器架构:当前方法允许修改头的数量,这对应于我们可以向前看的 token 数量。增加头的数量也会增加所需的额外计算量和训练的复杂性。实践中,对于语言模型,我们发现 3-4 个头效果良好,而代码模型可以从 6-8 个头中获益。
  2. 计算量:增加头的数量会导致计算量在两个方面增加,一是单个前向传递的延迟增加,二是多个 token 所需的计算量增加。如果推测器在更多头的情况下不够准确,将导致计算资源浪费,增加延迟并降低吞吐量。
  3. 内存:增加的计算量会被每次前向传递所需的高带宽内存 (HBM) 往返次数抵消。请注意,如果我们成功预测 3 个 token,我们就能节省三次 HBM 往返时间。

我们最终确定语言模型使用 3-4 个头,代码模型使用 6-8 个头。对于 7B 到 20B 不同大小的模型,我们观察到与非推测解码相比,延迟显著改善且吞吐量没有损失。当批处理大小超过 64 时,我们开始观察到吞吐量下降,这在实践中很少发生。

推测解码:训练

推测解码有两种主要方法,一种是利用较小的模型(例如,使用 Llama 7B 作为 Llama 70B 的推测器),另一种是附加推测头(并对其进行训练)。在我们的实验中,我们发现附加推测头的方法在模型质量和延迟收益方面都更有效。

推测器架构

Medusa 使推测解码流行起来;他们的方法是在现有模型上添加一个头,然后训练它进行推测。我们通过使“头”具有层次结构来修改 Medusa 架构,其中每个头阶段预测一个 token,然后将其馈送到下一个头阶段。这些多阶段头在下图中描绘。我们正在探索通过在多个阶段和基础模型之间共享嵌入表来最小化它的方法。

A simple architecture diagram for a 3-headed multi-stage  speculator. Z is the state from the base model.

图 4:一个简单的 3 头多阶段推测器架构图。Z 是来自基础模型的状态。

推测器训练

出于效率考虑,我们采用两阶段方法来训练推测器。在第一阶段,我们在具有长序列长度(4k token)的小批次上进行训练,并使用标准的因果语言模型方法进行训练。在第二阶段,我们使用由基础模型生成的具有短序列长度(256 token)的大批次。在此训练阶段,我们调整头以匹配基础模型的输出。通过大量实验,我们发现第一阶段与第二阶段的步骤比例为 5:2 效果良好。我们在下图中描绘了这些阶段的进展。我们使用 PyTorch FSDP 和 IBM FMS 进行推测器的训练。

Per-head training loss curves for Llama2-13B speculator training, phase 1 and 2

图 5:Llama2-13B 推测器训练的每个头训练损失曲线,阶段 1 和 2

结论与未来工作

通过这篇博文,我们发布了一种新的推测解码方法以及以下资产:

  1. 用于改善一系列模型(Llama3 8B、Llama2 13B、Granite 7B 和 CodeLlama 13B)的 token 间隔延迟的模型
  2. 用于推理的生产级代码
  3. 用于训练推测器的秘籍

我们正在努力训练适用于 Llama3 70B 和 Mistral 模型的推测器,并邀请社区贡献力量并帮助改进我们的框架。我们也乐于与 vLLMTGI 等主要的开源服务框架合作,将我们的推测解码方法贡献回去,造福社区。

致谢

有几个团队帮助我们在推理延迟方面取得了这些改进。我们要感谢 vLLM 团队以清晰且可重用的方式创建了分页注意力内核。我们感谢 Meta 的 PyTorch 团队,他们对这篇博文提供了反馈,并持续努力优化 PyTorch 的使用。特别感谢 IBM Research 的内部生产团队,他们将此原型投入生产并使其更加稳健。特别鸣谢 Stas Bekman 对博文提供的富有洞察力的评论,这些评论帮助改进了计算、内存和推测器有效性之间权衡的解释。

分页注意力内核由 Josh Rosenkranz 和 Antoni Viros i Martin 集成到 IBM FMS 中。推测器架构和训练由 Davis Wertheimer, Pavithra Ranganathan 和 Sahil Suneja 完成。建模代码与推理服务器的集成由 Thomas Parnell, Nick Hill 和 Prashant Gupta 完成。