投机解码是一种用于推理的优化技术,它在生成当前 token 的同时,对未来的 token 进行有根据的猜测,所有这些都在一次前向传递中完成。它结合了一种验证机制,以确保这些投机 token 的正确性,从而保证投机解码的总体输出与 vanilla 解码的输出相同。优化大型语言模型 (LLM) 的推理成本可以说是降低生成式 AI 成本和提高其普及率的最关键因素之一。为了实现这一目标,有各种推理优化技术可供选择,包括自定义内核、输入请求的动态批处理以及大型模型的量化。
在这篇博文中,我们将提供投机解码指南,并演示它如何与其他优化共存。我们很自豪地开源了以下内容,其中包括 Llama3 模型的第一个投机器
- 用于 Meta Llama3 8B、IBM Granite 7B lab、Meta Llama2 13B 和 Meta Code Llama2 13B 的投机器模型。
- 通过 IBM 的 HF TGI 分支进行推理的代码。
- 用于训练您自己的投机器和相应配方的代码。
我们已在拥有数千名日活跃用户的内部生产级环境中部署了这些投机器,并观察到语言模型(Llama3 8B、Llama2 13B 和 IBM Granite 7B)的速度提升了 2 倍,IBM Granite 20B 代码模型的速度提升了 3 倍。我们在本技术报告中详细解释了我们的方法,并计划在即将发布的 ArXiv 论文中进行深入分析。
投机解码:推理
我们在内部生产环境中运行 IBM TGIS,该环境具有持续批处理、融合内核和量化内核等优化。为了在 TGIS 中启用投机解码,我们修改了来自 vLLM 的分页注意力内核。在下文中,我们将描述推理引擎的关键更改,以启用投机解码。
投机解码基于以下前提:模型足够强大,可以在一次前向传递中预测多个 token。但是,当前的推理服务器经过优化,每次只能预测单个 token。在我们的方法中,我们为 LLM 附加了多个投机头(除了通常的一个头),以预测第 N+1 个、第 N+2 个、第 N+3 个……token。例如,3 个头将预测 3 个额外的 token。投机器架构的详细信息将在本博客的后面部分进行解释。在推理过程中实现效率和正确性有两个挑战 - 一个是在不复制 KV 缓存的情况下进行预测,另一个是验证预测是否与原始模型的结果匹配。
在典型的生成循环中,在单个前向步骤中处理提示后,长度为 1 的序列(预测的下一个 token)与 kv 缓存一起馈送到模型的前向传递中。在朴素的投机解码实现中,每个投机头都有自己的 kv 缓存,但我们修改了 vLLM 项目中开发的分页注意力内核,以实现高效的 kv 缓存维护。这确保了吞吐量不会在较大的批处理大小下降低。此外,我们修改了注意力掩码以启用第 N+1 个 token 的验证,从而启用投机解码,而不会偏离原始模型的输出。此实现的详细信息在此处捕获。
结果
我们使用一个简单的提示来说明使用 Meta 的 Llama2 13B 聊天版本获得的速度提升。
图 2:非投机生成(左)与投机生成(右)的可视化图示
我们在内部生产环境中部署了上述解决方案。下图报告了两个指标 - 首个 token 的时间 (TTFT) 和 token 间延迟 (ITL),以及不同数量的并发用户(在图表线上的数字中捕获)。我们观察到,对于所有批处理大小,投机解码版本的 Llama2 13B 聊天模型速度几乎是非投机版本的两倍,Granite 20B 代码模型的速度几乎是非投机版本的三倍。我们观察到较小的模型(IBM 的 Granite 7B 和 Meta Llama3 8B 模型)也表现出类似的行为。
图 3:Llama 13B 的首个 token 的时间(TTFT - 左)和 token 间延迟(ITL - 右),并发用户数在图表上指示
图 4:Granite 20B 代码的首个 token 的时间(TTFT - 左)和 token 间延迟(ITL - 右),并发用户数在图表上指示
关于效率的说明
我们进行了大量实验,以确定投机器训练的正确配置。这些是
- 投机器架构:当前的方法允许修改头的数量,这映射到我们可以向前看的 token 数量。增加头的数量也会增加所需的额外计算量和训练的复杂性。在实践中,对于语言模型,我们发现 3-4 个头在实践中效果良好,而我们发现代码模型可以从 6-8 个头中获益。
- 计算:增加头的数量会导致两个维度的计算量增加,一个是单个前向传递的延迟增加,另一个是多个 token 所需的计算量增加。如果投机器在更多头的情况下不准确,则会导致浪费计算,从而增加延迟并降低吞吐量。
- 内存:增加的计算量被每次前向传递都需要完成的 HBM 往返所抵消。请注意,如果我们正确预测了 3 个 token 的前瞻,我们节省了 3 个 HBM 往返时间。
我们为语言模型确定了 3-4 个头,为代码模型确定了 6-8 个头,并且在从 7B 到 20B 的不同模型尺寸中,我们观察到与非投机解码相比,延迟显着改善,而吞吐量没有损失。我们开始观察到吞吐量在批处理大小超过 64 时会降低,这种情况在实践中很少发生。
投机解码:训练
投机解码有两种广泛的方法,一种是利用较小的模型(例如,Llama 7B 作为 Llama 70B 的投机器),另一种是附加投机器头(并对其进行训练)。在我们的实验中,我们发现附加投机器头的方法在模型质量和延迟增益方面都更有效。
投机器架构
Medusa 使投机解码流行起来;他们的方法是在现有模型中添加一个头,然后训练该头进行投机。我们通过使“头”分层来修改 Medusa 架构,其中每个头阶段预测单个 token,然后将其馈送到下一个头阶段。这些多阶段头如下图所示。我们正在探索通过在多个阶段和基础模型之间共享嵌入表来最小化嵌入表的方法。
图 4:3 头多阶段投机器的简单架构图。Z 是来自基础模型的状态。
投机器训练
我们采用两阶段方法来训练投机器,以提高效率。在第一阶段,我们使用长序列长度(4k token)的小批量进行训练,并使用标准因果 LM 方法进行训练。在第二阶段,我们使用从基础模型生成的小序列长度(256 token)的大批量。在此训练阶段,我们调整头以匹配基础模型的输出。通过大量实验,我们发现阶段 1 与阶段 2 的步骤比例为 5:2 时效果良好。下图描述了这些阶段的进展情况。我们使用 PyTorch FSDP 和 IBM FMS 进行投机器的训练。
图 5:Llama2-13B 投机器训练的每个头训练损失曲线,阶段 1 和 2
结论和未来工作
通过这篇博客,我们发布了一种新的投机解码方法和以下资产
- 用于提高一系列模型(Llama3 8B、Llama2 13B、Granite 7B 和 CodeLlama 13B)的 token 间延迟的模型
- 生产质量的推理代码
- 用于训练投机器的配方
我们正在努力训练 Llama3 70B 和 Mistral 模型的投机器,并邀请社区贡献力量并帮助改进我们的框架。我们也希望与主要的开源服务框架(如 vLLM 和 TGI)合作,将我们的投机解码方法贡献回社区,以造福社区。
致谢
有多个团队帮助我们实现了推理的这些延迟改进。我们要感谢 vLLM 团队以简洁且可重用的方式创建了分页注意力内核。我们向 Meta 的 PyTorch 团队表示感谢,他们帮助提供了有关此博客的反馈,并在 PyTorch 的最佳使用方面做出了持续努力。特别感谢我们在 IBM Research 的内部生产团队,他们将此原型投入生产并使其更加完善。还要向 Stas Bekman 致敬,他为博客提供了富有洞察力的评论,从而改进了对计算、内存和投机器有效性之间权衡的解释。
分页注意力内核由 Josh Rosenkranz 和 Antoni Viros i Martin 集成到 IBM FMS 中。投机器架构和训练由 Davis Wertheimer、Pavithra Ranganathan 和 Sahil Suneja 完成。建模代码与推理服务器的集成由 Thomas Parnell、Nick Hill 和 Prashant Gupta 完成。