跳转到主要内容
博客

用于长上下文推理的 Flash-Decoding

作者: 2023 年 10 月 13 日2024 年 11 月 16 日暂无评论

动机

大型语言模型(LLM),例如 ChatGPT 或 Llama,最近受到了前所未有的关注。然而,它们的运行成本仍然高昂。尽管生成一个响应可能只需约 0.01 美元(AWS 上一个 8xA100 实例的几秒钟费用),但当扩展到数十亿用户时,这些成本会迅速累积,而这些用户可能每天与 LLM 进行多次交互。某些用例的成本更高,例如代码自动补全,因为它在每次输入新字符时都会运行。随着 LLM 应用的增多,即使是生成时间上的微小效率提升也能产生巨大影响。

LLM 推理(或称“解码”)是一个迭代过程:每次生成一个 token。生成 N 个 token 的完整句子需要模型进行 N 次前向传播。幸运的是,可以缓存先前计算的 token:这意味着单个生成步骤不依赖于上下文长度,除了一个操作,即注意力机制。该操作的性能随着上下文长度的增加而下降。

LLM 有许多重要的新兴用例,它们利用长上下文。通过更长的上下文,LLM 可以对更长的文档进行推理,无论是总结还是回答相关问题;它们可以跟踪更长的对话;甚至可以在编写代码之前处理整个代码库。例如,大多数 LLM 在 2022 年的上下文长度高达 2k (GPT-3),但我们现在有了扩展到 32k 的开源 LLM (Llama-2-32k),甚至最近达到了 100k (CodeLlama)。在这种情况下,注意力机制在推理期间占据了相当大的时间比例。

在批处理大小维度上进行扩展时,即使上下文相对较小,注意力机制也可能成为瓶颈。这是因为读取内存的量随着批处理维度的增加而扩展,而对于模型的其余部分,它仅取决于模型大小。

我们提出了一种名为 Flash-Decoding 的技术,可显著加快推理期间的注意力机制,使超长序列的生成速度最高提升 8 倍。主要思想是尽可能快地并行加载键和值,然后单独重新缩放和组合结果,以保持正确的注意力输出。

用于解码的多头注意力

在解码过程中,生成的每个新 token 都需要关注所有先前的 token,以计算

softmax(queries @ keys.transpose) @ values

在训练情况下,此操作已通过 FlashAttention(最近的 v1 和 v2)进行了优化,其中瓶颈是读取和写入中间结果(例如 Q @ K^T)的内存带宽。然而,这些优化不直接适用于推理情况,因为瓶颈不同。对于训练,FlashAttention 在批处理大小和查询长度维度上并行化。在推理过程中,查询长度通常为 1:这意味着如果批处理大小小于 GPU 上的流式多处理器 (SM) 数量(A100 为 108),则操作将仅使用 GPU 的一小部分!在使用长上下文时尤其如此,因为它需要较小的批处理大小才能适应 GPU 内存。当批处理大小为 1 时,FlashAttention 将使用不到 1% 的 GPU!

FlashAttention 仅在查询块和批处理大小之间并行化,并且在解码期间无法完全占用整个 GPU

注意力机制也可以使用矩阵乘法原语完成——不使用 FlashAttention。在这种情况下,操作完全占用 GPU,但启动许多写入和读取中间结果的内核,这不是最优的。

一种更快的解码注意力机制:Flash-Decoding

我们的新方法 Flash-Decoding 基于 FlashAttention,并增加了一个新的并行化维度:键/值序列长度。它结合了上述两种方法的优点。像 FlashAttention 一样,它在全局内存中存储的额外数据非常少,然而即使批处理大小很小,只要上下文长度足够大,它也能充分利用 GPU。

Flash-Decoding 还在键和值之间并行化,代价是一个小的最终归约步骤

Flash-Decoding 分为 3 个步骤

  1. 首先,我们将键/值分割成更小的块。
  2. 我们使用 FlashAttention 并行计算查询与这些分割块的注意力。我们还为每行和每个分割块写入一个额外的标量:注意力值的对数和指数。
  3. 最后,我们通过对所有分割块进行归约来计算实际输出,并使用对数和指数来缩放每个分割块的贡献。

所有这些都是可能的,因为注意力/softmax 可以迭代计算。在 Flash-Decoding 中,它在两个级别使用:在分割块内部(类似于 FlashAttention),以及在分割块之间执行最终归约。

实际上,步骤 (1) 不涉及任何 GPU 操作,因为键/值块是完整键/值张量的视图。然后我们有 2 个独立的内核分别执行 (2) 和 (3)。

CodeLlama 34B 基准测试

为了验证这种方法,我们对 CodeLLaMa-34b 的解码吞吐量进行了基准测试。该模型与 Llama 2 具有相同的架构,更普遍地说,结果应该适用于许多 LLM。我们测量了在各种序列长度(从 512 到 64k)下的解码速度(以 tok/s 为单位),并比较了多种计算注意力的方法

  • Pytorch:使用纯 PyTorch 原语运行注意力(不使用 FlashAttention)
  • FlashAttention v2
  • FasterTransformer:使用 FasterTransformer 注意力核
  • Flash-Decoding
  • 以及一个上限,其计算方式是读取整个模型以及 KV 缓存所需的时间

Flash-Decoding 为超长序列的解码速度带来了高达 8 倍的提升,并且比替代方法具有更好的扩展性。

CodeLlama

对于小型提示,所有方法表现相似,但随着序列长度从 512 增加到 64k,除了 Flash-Decoding,其他方法都扩展性不佳。在这种情况下(批量大小为 1),使用 Flash-Decoding 时,序列长度的缩放对生成速度影响很小

组件级微基准测试

我们还在 A100 上,使用 f16 输入,针对各种序列长度和批处理大小,对缩放多头注意力进行了微基准测试。我们将批处理大小设置为 1,并使用 16 个维度为 128 的查询头,以及 2 个键/值头(分组查询注意力),这与 CodeLLaMa-34b 在 4 个 GPU 上运行时使用的维度相匹配。

    
设置 \ 算法PyTorch Eager (微秒)Flash-Attention v2.0.9 (微秒)Flash-Decoding (微秒)
B=256, seqlen=2563058.6390.563.4
B=128, seqlen=5123151.4366.367.7
B=64, seqlen=10243160.4364.877.7
B=32, seqlen=20483158.335258.5
B=16, seqlen=40963157401.757
B=8, seqlen=81923173.1529.256.4
B=4, seqlen=163843223582.758.2
B=2, seqlen=327683224.11156.160.3
B=1, seqlen=655361335.62300.664.4
B=1, seqlen=13107226644592.2106.6

多头注意力微基准测试,运行时间以微秒计。随着序列长度扩展到 64k,Flash-Decoding 实现了几乎恒定的运行时间。

之前测量到的端到端高达 8 倍的速度提升之所以成为可能,是因为注意力本身比 FlashAttention 快了 50 倍。直到序列长度达到 32k,注意力时间大致保持不变,因为 Flash-Decoding 能够充分利用 GPU。

使用 Flash-Decoding

Flash-decoding 可在以下位置获取

  • FlashAttention 包中,从版本 2.2 开始
  • 通过 xFormers 从 0.0.22 版本开始,通过 `xformers.ops.memory_efficient_attention`。调度器将根据问题大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个实现 Flash-Decoding 算法的高效 triton 内核。

在 FlashAttention 仓库的 这里 和 xFormers 仓库的 这里,提供了 LLaMa v2 / CodeLLaMa 解码的完整示例。我们还提供了 一个极简示例,展示了 LLaMa v1/v2 模型的高效解码代码,旨在实现快速、易读、具有教育意义和可修改性。

致谢

感谢 Erich Elsen、Ashish Vaswani 和 Michaël Benesty 提出这种分割 KVcache 加载的想法。我们要感谢 Jeremy Reizenstein、Patrick Labatut 和 Andrew Tulloch 的宝贵讨论,以及 Quentin Carbonneaux 为 xFormers 贡献了高效解码示例。我们还要感谢 Geeta Chauhan 和 Gregory Chanan 协助撰写文章,并更广泛地为本文在 PyTorch 博客上发表做出了贡献。