博客

用于长上下文推理的 Flash-Decoding

作者: 2023年10月13日2024年11月16日暂无评论

动机

以 ChatGPT 或 Llama 为代表的大语言模型(LLM)近期受到了前所未有的关注。然而,它们的运行成本依然极其高昂。尽管生成单个响应可能仅需约 0.01 美元(在 AWS 上使用 8xA100 实例运行几秒钟的成本),但当扩展到数十亿用户时,这些成本会迅速累积,尤其是这些用户每天可能与 LLM 进行多次交互。某些使用场景(如代码自动补全)成本更高,因为每输入一个新字符都会触发运行。随着 LLM 应用的不断增多,即使是生成时间上的微小效率提升,也能产生巨大的影响。

LLM 推理(或称“解码”)是一个迭代过程:标记(tokens)是一个接一个生成的。生成 N 个标记的完整句子需要对模型进行 N 次前向传播。幸运的是,可以缓存先前计算出的标记:这意味着除了注意力(attention)这一单一操作外,单次生成步骤不依赖于上下文长度。而注意力操作的规模无法很好地随上下文长度进行扩展。

目前,利用长上下文的 LLM 重要新兴用例越来越多。有了更长的上下文,LLM 可以分析更长的文档(用于总结或回答相关问题)、跟踪更长的对话,甚至在编写代码前处理整个代码库。例如,在 2022 年,大多数 LLM 的上下文长度上限为 2k(GPT-3),但如今我们已经拥有扩展到 32k(Llama-2-32k)甚至最近达到 100k(CodeLlama)的开源 LLM。在这种设置下,注意力机制占用了推理过程中相当大的一部分时间。

当在批处理大小(batch size)维度进行扩展时,即使在相对较小的上下文环境下,注意力机制也可能成为瓶颈。这是因为读取内存的数据量随批处理维度扩展,而模型其余部分仅取决于模型本身的大小。

我们提出了一种名为 Flash-Decoding 的技术,它可以显著加速推理过程中的注意力计算,使超长序列的生成速度提升高达 8 倍。其核心思想是尽可能快地并行加载键(keys)和值(values),然后分别对结果进行重缩放和合并,以保持正确的注意力输出。

用于解码的多头注意力

在解码过程中,生成的每一个新标记都需要关注所有之前的标记,以计算:

softmax(queries @ keys.transpose) @ values

这一操作在训练场景中已通过 FlashAttention(近期为 v1 和 v2)进行了优化,因为训练时的瓶颈在于读取和写入中间结果(例如 Q @ K^T)的内存带宽。然而,这些优化并不直接适用于推理场景,因为两者的瓶颈不同。对于训练,FlashAttention 在批处理大小和查询长度维度上进行并行化。在推理期间,查询长度通常为 1:这意味着如果批处理大小小于 GPU 上的流式多处理器(SM)数量(A100 为 108 个),该操作将仅使用 GPU 的一小部分!在使用长上下文时情况尤为严重,因为这需要更小的批处理大小以适配 GPU 内存。当批处理大小为 1 时,FlashAttention 的 GPU 利用率不到 1%!

FlashAttention 仅在查询块和批处理大小维度上进行并行化,在解码过程中无法实现对整个 GPU 的充分利用。

注意力计算也可以使用矩阵乘法原语(不使用 FlashAttention)来完成。在这种情况下,操作会完全占用 GPU,但会启动许多写入和读取中间结果的内核,这并不是最优的。

更快的解码注意力:Flash-Decoding

我们提出的新方法 Flash-Decoding 基于 FlashAttention,并增加了一个新的并行维度:键/值(keys/values)序列长度。它结合了上述两种方法的优势。与 FlashAttention 一样,它在全局内存中存储的额外数据极少;然而,只要上下文长度足够大,即使在批处理大小很小时,它也能充分利用 GPU。

Flash-Decoding 还在键和值维度上进行并行化,代价是需要一个较小的最终缩减(reduction)步骤。

Flash-Decoding 分为 3 个步骤:

  1. 首先,我们将键/值切分成更小的块。
  2. 我们使用 FlashAttention 并行计算查询与这些切块各自的注意力。我们还在每一行和每一切块中额外写入一个标量:注意力值的 log-sum-exp。
  3. 最后,我们通过对所有切块进行缩减(reducing)来计算最终输出,并使用 log-sum-exp 来缩放每个切块的贡献。

这一切之所以可能,是因为注意力/softmax 可以迭代计算。在 Flash-Decoding 中,它在两个层级上使用:在切块内部(类似于 FlashAttention),以及在切块之间进行最终的缩减。

在实践中,步骤 (1) 不涉及任何 GPU 操作,因为键/值块只是完整键/值张量的视图。随后我们有两个独立的内核分别执行 (2) 和 (3)。

CodeLlama 34B 上的基准测试

为了验证该方法,我们对 CodeLLaMa-34b 的解码吞吐量进行了基准测试。该模型与 Llama 2 具有相同的架构,通常情况下,这些结果可以推广到许多 LLM。我们测量了序列长度从 512 到 64k 的解码速度(单位:tok/s),并比较了多种注意力计算方式:

  • Pytorch:使用纯 PyTorch 原语运行注意力(不使用 FlashAttention)
  • FlashAttention v2
  • FasterTransformer:使用 FasterTransformer 注意力内核
  • Flash-Decoding
  • 以及一个计算出的理论上限:即从内存中读取整个模型及 KV-cache 所需的时间。

Flash-Decoding 为超大序列解锁了高达 8 倍的解码速度提升,并且比替代方法的扩展性更好。

CodeLlama

所有方法在短提示(prompts)下表现相似,但随着序列长度从 512 增加到 64k,除 Flash-Decoding 外,其他方法的性能均大幅下降。在这种模式下(批处理大小为 1),使用 Flash-Decoding 时,增加序列长度对生成速度的影响极小。

组件级微基准测试

我们还在 A100 上针对各种序列长度和批处理大小对 f16 输入的缩放多头注意力进行了微基准测试。我们将批处理大小设为 1,使用 16 个维度为 128 的查询头,以及 2 个键/值头(分组查询注意力),这与 CodeLLaMa-34b 在 4 个 GPU 上运行时使用的维度相匹配。

    
设置 \ 算法PyTorch Eager (微秒)Flash-Attention v2.0.9 (微秒)Flash-Decoding (微秒)
B=256, 序列长度=2563058.6390.563.4
B=128, 序列长度=5123151.4366.367.7
B=64, 序列长度=10243160.4364.877.7
B=32, 序列长度=20483158.335258.5
B=16, 序列长度=40963157401.757
B=8, 序列长度=81923173.1529.256.4
B=4, 序列长度=163843223582.758.2
B=2, 序列长度=327683224.11156.160.3
B=1, 序列长度=655361335.62300.664.4
B=1, 序列长度=13107226644592.2106.6

多头注意力微基准测试,运行时间(微秒)。随着序列长度扩展至 64k,Flash-Decoding 实现了几乎恒定的运行时间。

之前测得的高达 8 倍的端到端加速得以实现,是因为注意力计算本身比 FlashAttention 快达 50 倍。在序列长度达到 32k 之前,注意力时间大致保持不变,因为 Flash-Decoding 成功实现了对 GPU 的充分利用。

使用 Flash-Decoding

Flash-Decoding 可通过以下方式获取:

  • FlashAttention 软件包,从 2.2 版本开始。
  • 通过 xFormers 从 0.0.22 版本开始,通过 `xformers.ops.memory_efficient_attention` 使用。调度程序将根据问题大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不支持时,它会调度到一个高效的 Triton 内核,该内核实现了 Flash-Decoding 算法。

FlashAttention 仓库中提供了 LLaMa v2 / CodeLLaMa 解码的完整示例(点击此处),xFormers 仓库中也有提供(此处)。我们还提供了一个 LLaMa v1/v2 模型的高效解码最小化示例,旨在做到快速、易读、具有教育意义且易于修改。

致谢

感谢 Erich Elsen、Ashish Vaswani 和 Michaël Benesty 提出拆分 KVcache 加载这一想法。我们要感谢 Jeremy Reizenstein、Patrick Labatut 和 Andrew Tulloch 的宝贵讨论,以及 Quentin Carbonneaux 为 xFormers 贡献了高效的解码示例。我们还要感谢 Geeta Chauhan 和 Gregory Chanan 在撰写过程中提供的帮助,以及更广泛地为在 PyTorch 博客上发布本文所做的贡献。