博客

推测解码搭车客指南

作者 2024 年 5 月 2 日2024 年 11 月 13 日无评论

投机性解码是一种推理优化技术,它在生成当前 token 的同时,对未来 token 进行有根据的猜测,所有这些都在一次前向传播中完成。它包含一个验证机制,以确保这些猜测 token 的正确性,从而保证投机性解码的整体输出与普通解码的输出完全相同。优化大型语言模型 (LLM) 的推理成本,可以说是降低生成式 AI 成本和提高其普及率最关键的因素之一。为了实现这一目标,有各种推理优化技术可用,包括自定义内核、输入请求的动态批处理以及大型模型的量化。

在这篇博文中,我们提供一份投机性解码的指南,并演示它如何与其他优化技术共存。我们很荣幸开源以下内容,其中包括 Llama3 模型的第一款投机器。

  1. Meta Llama3 8B、IBM Granite 7B lab、Meta Llama2 13B 和 Meta Code Llama2 13B 的投机器。
  2. 通过 IBM 的 HF TGI 分支进行推理的代码。
  3. 训练您自己的投机器及其相应配方的代码。

我们在一个具有数千日常用户的内部生产级环境中部署了这些投机器,并观察到语言模型(Llama3 8B、Llama2 13B 和 IBM Granite 7B)的速度提升了 2 倍,在 IBM Granite 20B 代码模型上速度提升了 3 倍。我们在技术报告中详细解释了我们的方法,并计划在即将发布的 ArXiv 论文中进行深入分析。

投机性解码:推理

我们在内部生产环境中运行 IBM TGIS,该环境具有连续批处理、融合内核和量化内核等优化。为了在 TGIS 中启用投机性解码,我们修改了 vLLM 项目中的分页注意力内核。接下来,我们将描述推理引擎为启用投机性解码所做的关键更改。

投机性解码基于这样的前提:模型足够强大,可以在一次前向传播中预测多个 token。然而,当前的推理服务器已优化为一次仅预测一个 token。在我们的方法中,我们在 LLM 上附加了多个投机头(除了通常的那个)来预测 *N+1-、N+2-、N+3-th* 等 token。例如,3 个头将预测 3 个额外的 token。投机器架构的详细信息将在本博客的稍后部分进行解释。在推理过程中,实现*效率*和*正确性*存在两个挑战——一个是如何在不复制 KV 缓存的情况下进行预测,另一个是如何验证预测是否与原始模型的输出匹配。

在典型的生成循环中,在通过一次前向步骤处理完提示后,将长度为 1 的序列(预测的下一个 token)与 kv 缓存一起输入模型的正向传播。在朴素的投机性解码实现中,每个投机头都会有自己的 kv 缓存,但我们修改了 vLLM 项目中开发的分页注意力内核,以实现高效的 kv 缓存维护。这确保了在更大的批次大小时吞吐量不会降低。此外,我们修改了注意力掩码,以启用对 *N+1'th* token 的验证,从而在不偏离原始模型输出的情况下启用投机性解码。此实现的详细信息在此捕获。

结果

我们使用一个简单的提示来说明使用 Meta 的 Llama2 13B 聊天版本所获得的加速效果。

Visual illustration of the non-speculative generation (left) compared to speculative generation (right)

图 2:非投机生成(左)与投机生成(右)的视觉对比

我们在内部生产环境中部署了上述解决方案。下图报告了两个指标——首次 token 时间 (TTFT) 和 token 间延迟 (ITL),以及不同数量的并发用户(图中的曲线数字表示)。我们观察到,与所有批次大小的非投机版本相比,投机性解码版本在 Llama2 13B 聊天模型上的速度几乎快了 2 倍,在 Granite 20B 代码模型上的速度几乎快了 3 倍。我们对较小的模型(IBM 的 Granite 7B 和 Meta Llama3 8B 模型)也观察到了类似的行为。

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Llama 13B with number of concurrent users indicated on the graph

图 3:Llama 13B 的首次 token 时间 (TTFT – 左) 和 token 间延迟 (ITL – 右),图中的数字表示并发用户数

Time to first token (TTFT - left) and Inter-token latency (ITL - right) for Granite 20B Code with number of concurrent users indicated on the graph

图 4:Granite 20B 代码模型的首次 token 时间 (TTFT – 左) 和 token 间延迟 (ITL – 右),图中的数字表示并发用户数

关于效率的说明

我们进行了大量实验,以确定投机器训练的正确配置。这些实验包括:

  1. 投机器架构:当前的方法允许修改头的数量,这对应于我们可以超前的 token 数量。增加头的数量也会增加所需的额外计算量和训练的复杂性。在实践中,对于语言模型,我们发现 3-4 个头效果很好,而对于代码模型,我们发现 6-8 个头更有益。
  2. 计算:增加头的数量会在两个方面增加计算量:一是单个前向传播的延迟增加,二是多 token 计算所需的工作量。如果投机器在更多头上不够准确,就会导致计算浪费,增加延迟并降低吞吐量。
  3. 内存:计算量的增加通过每次前向传播需要执行的 HBM 往返次数来抵消。请注意,如果我们正确地获得了 3 个 token 的前瞻,我们就节省了三次 HBM 往返时间。

我们为语言模型选择了 3-4 个头,为代码模型选择了 6-8 个头,并且在从 7B 到 20B 的各种模型大小之间,我们观察到与非投机解码相比,延迟显着提高,而吞吐量没有损失。我们开始在批次大小超过 64 时观察到吞吐量下降,这种情况在实践中很少发生。

投机性解码:训练

投机性解码有两种广泛的方法:一种是利用一个较小的模型(例如,Llama 7B 作为 Llama 70B 的投机器),另一种是附加投机头(并训练它们)。在我们的实验中,我们发现附加投机头的方法在模型质量和延迟收益方面都更有效。

投机器架构

Medusa 使投机性解码流行起来;他们的方法是在现有模型上添加一个头,然后训练这个头进行投机。我们修改了 Medusa 架构,使“头”分层化,其中每个头阶段预测一个 token,然后将其馈送到下一头阶段。这些多阶段头如下图所示。我们正在探索通过在多个阶段和基础模型之间共享嵌入表来最小化嵌入表的方法。

A simple architecture diagram for a 3-headed multi-stage  speculator. Z is the state from the base model.

图 4:3 头多阶段投机器的简单架构图。Z 是基础模型的状态。

投机器训练

出于效率原因,我们采用了两阶段方法来训练投机器。在第一阶段,我们在具有长序列长度(4k token)的小批量上进行训练,并使用标准的因果 LM 方法进行训练。在第二阶段,我们使用从基础模型生成的短序列长度(256 token)的大批量。在此训练阶段,我们调整头以匹配基础模型的输出。通过大量实验,我们发现 5:2 的阶段 1 与阶段 2 的步数比例效果很好。我们在下图描绘了这些阶段的进展。我们使用 PyTorch FSDP 和 IBM FMS 进行投机器的训练。

Per-head training loss curves for Llama2-13B speculator training, phase 1 and 2

图 5:Llama2-13B 投机器训练(阶段 1 和 2)的每头训练损失曲线

结论和未来工作

通过这篇博客,我们发布了一种新的投机性解码方法以及以下资产:

  1. 用于提高一系列模型(Llama3 8B、Llama2 13B、Granite 7B 和 CodeLlama 13B)的 token 间延迟的模型。
  2. 生产质量的推理代码。
  3. 投机器训练的配方。

我们正在努力为 Llama3 70B 和 Mistral 模型训练投机器,并邀请社区贡献以及帮助改进我们的框架。我们还很乐意与 vLLM 和 TGI 等主要的开源服务框架合作,将我们的投机性解码方法贡献回去,以造福社区。

致谢

有几个团队帮助我们实现了推理延迟的提高。我们要感谢 vLLM 团队以一种干净且可重用的方式创建了分页注意力内核。我们感谢 Meta 的 PyTorch 团队,他们就这篇博文提供了反馈,并继续致力于 PyTorch 的最佳使用。特别感谢我们 IBM Research 的内部生产团队,他们将这个原型投入生产并进行了加固。特别感谢 Stas Bekman,他提供了对博文的深刻评论,从而改进了对计算、内存和投机器有效性之间权衡的解释。

分页注意力内核由 Josh Rosenkranz 和 Antoni Viros i Martin 集成到 IBM FMS 中。投机器架构和训练由 Davis Wertheimer、Pavithra Ranganathan 和 Sahil Suneja 完成。建模代码与推理服务器的集成由 Thomas Parnell、Nick Hill 和 Prashant Gupta 完成。