博客

推测解码搭车客指南

作者: 2024年5月2日2024年11月13日暂无评论

推测解码(Speculative decoding)是一种推理优化技术,它能在生成当前 Token 的同时,在单次前向传递中对后续 Token 进行有根据的预测。该技术结合了验证机制以确保这些推测 Token 的准确性,从而保证推测解码的最终输出与常规解码完全一致。优化大语言模型(LLM)的推理成本是降低生成式 AI 成本并提高其应用率的最关键因素之一。为实现这一目标,目前已有多种推理优化技术,包括自定义内核、输入请求的动态批处理以及大模型的量化。

在本篇博文中,我们将提供一份推测解码指南,并演示其如何与其他优化技术共存。我们很自豪地开源以下内容,其中包括首个针对 Llama3 模型的推测器(Speculator):

  1. 适用于 Meta Llama3 8BIBM Granite 7B labMeta Llama2 13BMeta Code Llama2 13B 的推测器模型。
  2. 通过 IBM 分支的 HF TGI 进行推理的代码。
  3. 用于训练您自己的推测器及其对应方案的代码。

我们已将这些推测器部署在拥有数千名日常用户的内部生产级环境中,并观察到语言模型(Llama3 8B、Llama2 13B 和 IBM Granite 7B)的推理速度提升了 2 倍,IBM 的 Granite 20B 代码模型的推理速度提升了 3 倍。我们在这份技术报告中详细解释了我们的方法,并计划在即将发表的 ArXiv 论文中进行深入分析。

推测解码:推理

我们在内部生产环境中运行 IBM TGIS,该环境具备连续批处理(continuous batching)、融合内核(fused kernels)和量化内核等优化措施。为了在 TGIS 中启用推测解码,我们修改了来自 vLLM 的分页注意力(paged attention)内核。接下来,我们将介绍为启用推测解码而对推理引擎所做的关键更改。

推测解码的前提是模型足够强大,能够在单次前向传递中预测多个 Token。然而,目前的推理服务器仅优化为一次预测一个 Token。在我们的方法中,我们向 LLM 附加了多个推测头(除了常规的一个之外),以预测 N+1、N+2、N+3 ... 个 Token。例如,3 个头将预测 3 个额外的 Token。推测器的架构细节将在本文的后续部分解释。在推理过程中要实现效率正确性面临两个挑战:一是如何在不复制 KV 缓存的情况下进行预测,二是如何验证预测结果与原始模型的输出是否匹配。

在典型的生成循环中,提示词(prompt)处理完并经过单次前向步骤后,序列长度为 1(预测的下一个 Token)的数据连同 kv-cache 一起输入到模型的前向传递中。在朴素的推测解码实现中,每个推测头都有自己的 kv-cache;但我们修改了 vLLM 项目中开发的分页注意力内核,以实现高效的 kv-cache 维护。这确保了在更大批处理量下吞吐量不会降低。此外,我们修改了注意力掩码以启用对 N+1 个 Token 的验证,从而在不偏离原始模型输出的情况下实现推测解码。该实现的详细信息记录在此处

结果

我们展示了使用简单的提示词对 Meta 的 Llama2 13B 聊天版本所获得的速度提升。

Visual illustration of the non-speculative generation (left) compared to speculative generation (right)

图 2:非推测性生成(左)与推测性生成(右)的视觉说明

我们将上述解决方案部署在内部生产环境中。下图报告了两个指标:首个 Token 时间 (TTFT) 和 Token 间延迟 (ITL),并涵盖了不同数量的并发用户(图中线条数字所示)。我们观察到,对于所有批处理大小,推测解码版本的 Llama2 13B 聊天模型速度几乎快了两倍,而 Granite 20B 代码模型则快了近三倍。对于较小的模型——IBM 的 Granite 7B 和 Meta Llama3 8B 模型——我们观察到了相似的表现。

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Llama 13B with number of concurrent users indicated on the graph

图 3:Llama 13B 的首个 Token 时间 (TTFT – 左) 和 Token 间延迟 (ITL – 右),图中显示了并发用户数

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Granite 20B Code with number of concurrent users indicated on the graph

图 4:Granite 20B Code 的首个 Token 时间 (TTFT – 左) 和 Token 间延迟 (ITL – 右),图中显示了并发用户数

关于效率的说明

我们进行了大量的实验来确定推测器训练的最佳配置。内容如下:

  1. 推测器架构:当前方法允许修改头的数量,这对应于我们可以向前看的 Token 数量。增加头的数量也会增加所需的额外计算量和训练复杂度。在实践中,对于语言模型,我们发现 3-4 个头效果良好;而对于代码模型,我们发现 6-8 个头可以带来更多收益。
  2. 计算:增加头的数量会导致计算量在两个维度上增加:一是单次前向传递的延迟增加,二是预测多个 Token 所需的计算量。如果推测器在头数较多时不够准确,会导致计算浪费,从而增加延迟并降低吞吐量。
  3. 内存:增加的计算量被每次前向传递所需的 HBM 往返次数所抵消。请注意,如果我们准确预测了 3 个 Token,我们就节省了 3 次 HBM 往返时间。

我们最终确定语言模型使用 3-4 个头,代码模型使用 6-8 个头。在 7B 到 20B 的不同模型规模中,我们观察到在不损失吞吐量的情况下,相比非推测解码,延迟有了显著改善。我们在批处理量超过 64 时开始观察到吞吐量下降,这种情况在实践中很少发生。

推测解码:训练

推测解码主要有两种方法:一种是利用较小的模型(例如,使用 Llama 7B 作为 Llama 70B 的推测器),另一种是附加推测头(并对其进行训练)。在我们的实验中,我们发现附加推测头的方法在模型质量和延迟增益方面更为有效。

推测器架构

Medusa 使推测解码变得普及;他们的方法是在现有模型上增加一个头,然后对其进行训练以进行推测。我们通过使“头”层次化来修改 Medusa 架构,其中每个头阶段预测一个 Token,然后将其输入到下一个头阶段。这些多阶段头如下图所示。我们正在探索通过在多个阶段和基础模型之间共享嵌入表来最小化嵌入表的方法。

A simple architecture diagram for a 3-headed multi-stage  speculator. Z is the state from the base model.

图 4:3 头多阶段推测器的简单架构图。Z 是来自基础模型的状态。

推测器训练

出于效率考虑,我们采用两阶段训练推测器的方法。在第一阶段,我们在小批次和长序列长度(4k tokens)上进行训练,并使用标准的因果语言模型(causal LM)方法。在第二阶段,我们使用由基础模型生成的大批次和短序列长度(256 tokens)。在此训练阶段,我们微调头以匹配基础模型的输出。通过多次实验,我们发现第一阶段与第二阶段的步数比为 5:2 时效果最佳。我们在下图中描绘了这些阶段的进展。我们使用 PyTorch FSDP 和 IBM FMS 来训练推测器。

Per-head training loss curves for Llama2-13B speculator training, phase 1 and 2

图 5:Llama2-13B 推测器训练第一阶段和第二阶段的每头训练损失曲线

结论与未来工作

通过本篇博文,我们发布了一种新的推测解码方法及以下资产:

  1. 用于改善一系列模型(Llama3 8B、Llama2 13B、Granite 7B 和 CodeLlama 13B)Token 间延迟的模型
  2. 生产质量的推理代码
  3. 训练推测器的方案

我们正在致力于训练 Llama3 70B 和 Mistral 模型的推测器,并邀请社区参与贡献,协助改进我们的框架。我们也希望能与 vLLMTGI 等主流开源服务框架合作,贡献我们的推测解码方法,造福社区。

致谢

有多个团队帮助我们实现了推理延迟的改进。我们要感谢 vLLM 团队以清晰且可重用的方式创建了分页注意力内核。我们向 Meta 的 PyTorch 团队表示感谢,他们为这篇博文提供了反馈,并在 PyTorch 的最佳使用方面做出了持续努力。特别感谢 IBM Research 的内部生产团队,他们将此原型投入生产并进行了加固。感谢 Stas Bekman 对博文提出了深刻见解,使得关于计算、内存和推测器有效性之间权衡的解释得到了改进。

分页注意力内核由 Josh Rosenkranz 和 Antoni Viros i Martin 集成到 IBM FMS 中。推测器架构和训练由 Davis Wertheimer、Pavithra Ranganathan 和 Sahil Suneja 完成。建模代码与推理服务器的集成由 Thomas Parnell、Nick Hill 和 Prashant Gupta 完成。