跳转到主要内容
博客

推测解码搭车客指南

作者: 2024 年 5 月 2 日2024 年 11 月 13 日暂无评论

投机解码是一种用于推理的优化技术,它在生成当前 token 的同时,在一次前向传播中对未来的 token 进行有根据的猜测。它包含一种验证机制,以确保这些猜测 token 的正确性,从而保证投机解码的整体输出与传统解码的输出一致。优化大型语言模型(LLMs)的推理成本,可以说是降低生成式 AI 成本和提高其采用率的最关键因素之一。为了实现这一目标,有多种推理优化技术可用,包括自定义内核、输入请求的动态批处理以及大型模型的量化。

在这篇博客文章中,我们提供了投机解码的指南,并展示了它如何与其他优化共存。我们很荣幸能开源以下内容,其中包括首个用于 Llama3 模型的投机解码器:

  1. 适用于 Meta Llama3 8BIBM Granite 7B labMeta Llama2 13BMeta Code Llama2 13B 的投机解码器模型。
  2. 通过 IBM 的 HF TGI 分支进行推理的代码。
  3. 用于训练您自己的投机解码器及其相应秘籍的代码。

我们已将这些投机解码器部署到拥有数千日常用户的内部生产级环境中,并观察到语言模型(Llama3 8B、Llama2 13B 和 IBM Granite 7B)的速度提升了 2 倍,IBM 的 Granite 20B 代码模型的速度提升了 3 倍。我们在这篇技术报告中详细解释了我们的方法,并计划在即将发布的 ArXiv 论文中进行深入分析。

投机解码:推理

我们在内部生产环境运行 IBM TGIS,该环境包含连续批处理、融合内核和量化内核等优化。为了在 TGIS 中启用投机解码,我们修改了来自 vLLM 的分页注意力(paged attention)内核。接下来,我们将描述推理引擎为启用投机解码所做的关键更改。

投机解码基于这样一个前提:模型足够强大,可以在一次前向传播中预测多个 token。然而,当前的推理服务器仅优化为一次预测一个 token。在我们的方法中,我们除了通常的头部之外,还为 LLM 附加了多个投机解码头,以预测 *N+1-、N+2-、N+3-th...* token。例如,3 个头将预测 3 个额外的 token。投机解码器架构的详细信息将在本博客的后面部分解释。在推理过程中实现*效率*和*正确性*存在两个挑战——一个是在不复制 KV-缓存的情况下进行预测,另一个是验证预测是否与原始模型的输出匹配。

在典型的生成循环中,提示符在一次前向步骤中处理后,长度为 1(预测下一个 token)的序列与 kv-缓存一起输入到模型的前向传播中。在朴素的投机解码实现中,每个投机解码头都会有自己的 kv-缓存,但我们修改了 vLLM 项目中开发的分页注意力内核,以实现高效的 kv-缓存维护。这确保了在更大批次大小下吞吐量不会降低。此外,我们修改了注意力掩码,以实现对 *N+1* 个 token 的验证,从而在不偏离原始模型输出的情况下启用投机解码。此实现的详细信息已此处提供。

结果

我们使用一个简单的提示符展示了通过 Meta 的 Llama2 13B 聊天版本获得的加速。

Visual illustration of the non-speculative generation (left) compared to speculative generation (right)

图 2:非投机生成(左)与投机生成(右)的视觉示例

我们将上述解决方案部署到内部生产环境中。下图报告了两个指标——首个 token 时间 (TTFT) 和 token 间延迟 (ITL),并展示了不同并发用户数(在图表线条上的数字中显示)下的结果。我们观察到,对于所有批处理大小,投机解码版本对于 Llama2 13B 聊天模型的速度几乎是非投机版本的两倍,对于 Granite 20B 代码模型的速度几乎是三倍。我们观察到较小模型(IBM 的 Granite 7B 和 Meta Llama3 8B 模型)也有类似的行为。

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Llama 13B with number of concurrent users indicated on the graph

图 3:Llama 13B 的首个 token 时间 (TTFT – 左) 和 token 间延迟 (ITL – 右),图表上标明了并发用户数

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Granite 20B Code with number of concurrent users indicated on the graph

图 4:Granite 20B 代码模型的首个 token 时间 (TTFT – 左) 和 token 间延迟 (ITL – 右),图表上标明了并发用户数

效率说明

我们进行了大量实验来确定投机解码器训练的正确配置。这些配置是:

  1. 投机解码器架构:当前方法允许修改头的数量,这对应于我们可以向前看多少个 token。增加头的数量也会增加所需的额外计算量和训练复杂性。在实践中,对于语言模型,我们发现 3-4 个头效果很好,而对于代码模型,我们发现 6-8 个头可以带来好处。
  2. 计算:增加头数会导致计算量在两个维度上增加,一是单次前向传播的延迟增加,二是多个 token 所需的计算量增加。如果投机解码器在拥有更多头时不够准确,将导致计算浪费,从而增加延迟并降低吞吐量。
  3. 内存:增加的计算量被每次前向传播需要执行的 HBM 往返次数所抵消。请注意,如果我们将 3 个 token 的前瞻预测正确,我们就可以节省三次 HBM 往返时间。

我们最终确定语言模型使用 3-4 个头,代码模型使用 6-8 个头。在 7B 到 20B 的不同模型大小范围内,与非投机解码相比,我们观察到显著的延迟改进,同时没有吞吐量损失。当批处理大小超过 64 时,我们开始观察到吞吐量下降,这种情况在实践中很少发生。

投机解码:训练

投机解码有两种主要方法,一种是利用较小的模型(例如,Llama 7B 作为 Llama 70B 的投机解码器),另一种是附加投机解码器头(并对其进行训练)。在我们的实验中,我们发现附加投机解码器头的方法在模型质量和延迟增益方面都更有效。

投机解码器架构

Medusa 使投机解码变得流行;他们的方法是向现有模型添加一个头部,然后训练该头部进行投机。我们通过使“头部”分层来修改 Medusa 架构,其中每个头部阶段预测一个 token,然后将其馈送到下一个头部阶段。这些多阶段头部如下图所示。我们正在探索通过在多个阶段和基础模型之间共享嵌入表来最小化嵌入表的方法。

A simple architecture diagram for a 3-headed multi-stage  speculator. Z is the state from the base model.

图 4:一个简单的三头多阶段投机解码器架构图。Z 是来自基础模型的状态。

投机解码器训练

出于效率考虑,我们采用两阶段方法来训练投机解码器。在第一阶段,我们使用标准因果语言模型(Causal LM)方法在小批量和长序列长度(4k tokens)上进行训练。在第二阶段,我们使用从基础模型生成的短序列长度(256 tokens)的大批量数据。在此训练阶段,我们调整头部以匹配基础模型的输出。通过大量实验,我们发现第一阶段与第二阶段的步数比为 5:2 效果很好。下图展示了这些阶段的进展。我们使用 PyTorch FSDP 和 IBM FMS 来训练投机解码器。

Per-head training loss curves for Llama2-13B speculator training, phase 1 and 2

图 5:Llama2-13B 投机解码器训练的每头训练损失曲线,阶段 1 和阶段 2

结论和未来工作

通过这篇博客,我们发布了一种新的投机解码方法以及以下资产:

  1. 用于改进一系列模型(Llama3 8B、Llama2 13B、Granite 7B 和 CodeLlama 13B)的 token 间延迟的模型。
  2. 生产级推理代码
  3. 训练投机解码器的秘籍

我们正在努力训练 Llama3 70B 和 Mistral 模型的投机解码器,并邀请社区贡献并帮助改进我们的框架。我们也很乐意与主要的开源服务框架合作,例如 vLLMTGI,以回馈我们的投机解码方法,造福社区。

致谢

有几个团队帮助我们实现了这些推理延迟的改进。我们要感谢 vLLM 团队以清晰可重用的方式创建了分页注意力内核。我们向 Meta 的 PyTorch 团队表示感谢,他们为本博客提供了反馈,并持续努力优化 PyTorch 的使用。特别感谢 IBM Research 的内部生产团队,他们将这个原型投入生产并使其稳固。感谢 Stas Bekman 对本博客提供了富有洞察力的评论,从而改进了计算、内存和投机解码器效率之间权衡的解释。

分页注意力内核由 Josh Rosenkranz 和 Antoni Viros i Martin 集成到 IBM FMS 中。投机解码器架构和训练由 Davis Wertheimer、Pavithra Ranganathan 和 Sahil Suneja 完成。建模代码与推理服务器的集成由 Thomas Parnell、Nick Hill 和 Prashant Gupta 完成。