作者:Tri Dao, Daniel Haziza, Francisco Massa, Grigory Sizov

动机

大型语言模型 (LLM),如 ChatGPT 或 Llama,最近受到了前所未有的关注。然而,它们的运行成本仍然非常高昂。即使生成一个响应可能只花费约 $0.01(在 AWS 上使用 8 个 A100 实例运行几秒钟),但当扩展到数十亿用户时,这些成本会迅速累积,因为用户每天可能与这些 LLM 进行多次交互。有些使用案例成本更高,例如代码自动补全,因为它会在每次输入新字符时运行。随着 LLM 应用数量激增,即使是生成时间上的微小效率提升,也能产生巨大影响。

LLM 推理(或称“解码”)是一个迭代过程:一次生成一个 token。生成包含 N 个 token 的完整句子需要通过模型进行 N 次前向传播。幸运的是,可以缓存之前计算的 token:这意味着除了注意力(attention)这个操作外,单个生成步骤并不取决于上下文长度。注意力操作随着上下文长度的增加而扩展性较差。

LLM 的许多重要新兴使用案例都利用了长上下文。通过更长的上下文,LLM 可以对更长的文档进行推理,无论是总结还是回答关于文档的问题;它们可以跟踪更长的对话;甚至可以在编写代码之前处理整个代码库。例如,2022 年大多数 LLM 的上下文长度不超过 2k (GPT-3),但我们现在拥有扩展到 32k 的开源 LLM(Llama-2-32k),或者最近甚至达到了 100k(CodeLlama)。在这种设置下,注意力操作在推理过程中占据了相当大的一部分时间。

在批量大小(batch size)维度上进行扩展时,即使在相对较小的上下文中,注意力也可能成为瓶颈。这是因为读取内存的数据量随着批量维度的增加而增加,而模型其余部分的内存读取量仅取决于模型大小。

我们提出了一种名为 Flash-Decoding 的技术,它显著加快了推理过程中的注意力计算速度,对于非常长的序列,生成速度提高了高达 8 倍。主要思想是并行加载键(keys)和值(values),速度越快越好,然后单独重新缩放并组合结果,以获得正确的注意力输出。

用于解码的多头注意力

在解码过程中,生成的每个新 token 都需要关注所有先前的 token,以计算

softmax(queries @ keys.transpose) @ values

在训练场景中,FlashAttention(最近发布了 v1 和 v2)对这个操作进行了优化,其中的瓶颈在于读写中间结果(例如 Q @ K^T)所需的内存带宽。然而,这些优化不能直接应用于推理场景,因为瓶颈不同。在训练时,FlashAttention 在批量大小和查询长度维度上并行。在推理时,查询长度通常为 1:这意味着如果批量大小小于 GPU 上的流式多处理器 (SM) 数量(A100 上为 108 个),操作将只使用 GPU 的一小部分!在使用长上下文时尤其如此,因为长上下文需要更小的批量大小才能适应 GPU 内存。当批量大小为 1 时,FlashAttention 将使用不到 GPU 的 1%!

FlashAttention

FlashAttention 仅在查询块和批量大小维度上进行并行,无法在解码过程中完全占用整个 GPU

注意力计算也可以使用矩阵乘法原语来完成,而不使用 FlashAttention。在这种情况下,操作可以完全占用 GPU,但会启动许多写入和读取中间结果的核函数(kernels),这并不是最优的。

用于解码的更快注意力计算:Flash-Decoding

我们的新方法 Flash-Decoding 基于 FlashAttention,并增加了一个新的并行维度:键/值(keys/values)序列长度。它结合了上述两种方法的优点。像 FlashAttention 一样,它向全局内存存储的额外数据非常少,但即使在批量大小很小的情况下,只要上下文长度足够大,它也能充分利用 GPU。

Flash-Decoding

Flash-Decoding 还在键(keys)和值(values)维度上并行,代价是最后需要一个小的归约(reduction)步骤

Flash-Decoding 分三个步骤工作

  1. 首先,我们将键/值(keys/values)分割成更小的块。
  2. 我们使用 FlashAttention 并行计算查询(query)与每个分割块的注意力。我们还为每行和每个分割块写入 1 个额外的标量:即注意力值的 log-sum-exp。
  3. 最后,我们通过对所有分割块进行归约(reducing)来计算实际输出,使用 log-sum-exp 来缩放每个分割块的贡献。

所有这一切之所以可能,是因为注意力/softmax 可以迭代计算。在 Flash-Decoding 中,它在两个层面使用:在分割块内部(类似于 FlashAttention),以及跨分割块执行最终的归约。

实际上,步骤 (1) 不涉及任何 GPU 操作,因为键/值块是完整键/值张量的视图。然后,我们有两个独立的核函数(kernels)分别执行步骤 (2) 和 (3)。

CodeLlama 34B 基准测试

为了验证这种方法,我们对 CodeLLaMa-34b 的解码吞吐量进行了基准测试。该模型与 Llama 2 具有相同的架构,更普遍地说,结果应该可以泛化到许多 LLM。我们测量了在不同序列长度(从 512 到 64k)下的解码速度(单位为 token/秒),并比较了多种注意力计算方法

  • PyTorch: 使用纯 PyTorch 原语进行注意力计算(不使用 FlashAttention)
  • FlashAttention v2
  • FasterTransformer: 使用 FasterTransformer 注意力核函数(kernel)
  • Flash-Decoding
  • 以及一个上限,其计算方法是读取整个模型及其 KV 缓存(KV-cache)所需的时间

Flash-Decoding 将非常大序列的解码速度提高了高达 8 倍,并且比其他方法具有更好的扩展性。

CodeLlama

所有方法在处理短提示(prompts)时性能相似,但随着序列长度从 512 增加到 64k,除了 Flash-Decoding 外,其他方法的扩展性都很差。在这种设置下(批量大小为 1),使用 Flash-Decoding 时,序列长度的增加对生成速度影响很小

组件级微基准测试

我们还在 A100 上对不同序列长度和批量大小下的缩放多头注意力(scaled multi-head attention)进行了微基准测试,输入数据格式为 f16。我们将批量大小设置为 1,使用 16 个维度为 128 的查询头(query heads),以及 2 个键/值头(key/value heads,即分组查询注意力 grouped-query attention),这与 CodeLLaMa-34b 在 4 个 GPU 上运行时使用的维度相匹配。

       
设置 \\ 算法 PyTorch Eager (微秒) Flash-Attention v2.0.9 (微秒) Flash-Decoding (微秒)
B=256, seqlen=256 3058.6 390.5 63.4
B=128, seqlen=512 3151.4 366.3 67.7
B=64, seqlen=1024 3160.4 364.8 77.7
B=32, seqlen=2048 3158.3 352 58.5
B=16, seqlen=4096 3157 401.7 57
B=8, seqlen=8192 3173.1 529.2 56.4
B=4, seqlen=16384 3223 582.7 58.2
B=2, seqlen=32768 3224.1 1156.1 60.3
B=1, seqlen=65536 1335.6 2300.6 64.4
B=1, seqlen=131072 2664 4592.2 106.6

多头注意力微基准测试,运行时间以微秒为单位。Flash-Decoding 在序列长度扩展到 64k 时实现了几乎恒定的运行时间。

之前测量到的端到端高达 8 倍的加速,是因为注意力计算本身比 FlashAttention 快了高达 50 倍。直到序列长度达到 32k,注意力计算时间大致保持恒定,因为 Flash-Decoding 能够充分利用 GPU。

使用 Flash-Decoding

Flash-decoding 可用

  • FlashAttention 包中,版本 2.2 及以上
  • 通过 xFormers,从版本 0.0.22 开始,通过 xformers.ops.memory_efficient_attention。调度器将根据问题大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个实现 Flash-Decoding 算法的高效 triton 核函数(kernel)。

使用 LLaMa v2 / CodeLLaMa 进行解码的完整示例可在 FlashAttention 仓库的 此处 以及 xFormers 仓库的 此处 找到。我们还提供了一个针对 LLaMa v1/v2 模型的高效解码代码的 极简示例,旨在实现快速、易读、具有教育意义且可修改。

致谢

感谢 Erich Elsen, Ashish Vaswani 和 Michaël Benesty 提出了分割 KV 缓存加载的想法。我们要感谢 Jeremy Reizenstein, Patrick Labatut 和 Andrew Tulloch 提供了宝贵的讨论,感谢 Quentin Carbonneaux 为 xFormers 贡献了高效的解码示例。我们还要感谢 Geeta Chauhan 和 Gregory Chanan 在撰写本文以及推动其在 PyTorch 博客上发表方面提供的帮助和贡献。