跳转到主要内容

低精度 KV 缓存的高效解码分组查询注意力

引言

生成式人工智能凭借其像人类一样生成内容的能力席卷了全球。许多生成式人工智能工具都由大型语言模型(LLM)驱动,例如 Meta 的 Llama 模型和 OpenAI 的 ChatGPT。LLM 面临的主要挑战之一是支持大“上下文长度”(也称为“序列长度”)。上下文长度是指模型用于理解输入上下文并生成响应的令牌数量。更长的上下文长度通常意味着响应具有更高的精度和质量。然而,长上下文长度在计算和内存方面都是密集的。这主要是由于以下原因:

  • 注意力层的计算复杂度与上下文长度成比例增长(增长率取决于注意力算法)。因此,在使用长上下文长度时,注意力层可能会成为瓶颈,尤其是在预填充阶段,此时注意力是计算密集型的。
  • KV 缓存大小与上下文长度线性增长,从而对内存需求施加更高的压力,进而减慢已经内存受限的注意力解码。此外,由于内存容量有限,当 KV 缓存变大时,批处理大小会减小,这通常会导致吞吐量下降。

与其他上述问题相比,计算复杂度的增长更难解决。解决 KV 缓存大小增长问题的一种方法是使用低精度 KV 缓存。根据我们的实验,在 Meta Llama 2 推理的解码阶段,分组 INT4 量化在准确性方面与 BF16 KV 缓存提供了可比的结果。然而,我们没有观察到任何延迟改进,尽管注意力解码层读取的数据减少了 4 倍。这意味着 INT4 注意力在利用宝贵的 HBM 带宽方面比 BF16 注意力效率低 4 倍。

在这份说明中,我们讨论了我们对 INT4 GQA(分组查询注意力——我们在 LLM 推理阶段使用的注意力层)应用的 CUDA 优化,以将其性能在 NVIDIA A100 GPU 上提高高达 1.8 倍,在 NVIDIA H100 GPU 上提高高达 1.9 倍

  • 优化的 CUDA INT4 GQAINT4 Flash-Decoding GQA(我们在上述实验中使用的性能最佳的 INT4 GQA)在 A100 上性能提高了 1.4 倍至 1.7 倍,在 H100 上性能提高了 1.09 倍至 1.3 倍
  • 优化的 CUDA INT4 GQABF16 Flash-Decoding GQAA100 上性能提高了 1.5 倍至 1.7 倍,在 H100 上性能提高了 1.4 倍至 1.7 倍

背景

用于 LLM 推理的 GQA

分组查询注意力(GQA)是多头注意力(MHA)的一种变体,其中每个 KV 缓存头在查询头组之间共享。我们的 LLM 推理在预填充和解码阶段都采用 GQA 作为注意力层,以减少 KV 缓存的容量需求。我们在推理中使用多个 GPU,其中 KV 缓存和查询头分布在不同的 GPU 上。每个 GPU 运行一个带有一个 KV 头和一组 Q 头的注意力层。因此,从单个 GPU 的角度来看,GQA 组件也可以描述为 MQA(多查询注意力)

解码 GQA 的简化工作流程如图 1 所示。GQA 接受三个主要输入:输入查询(表示为 `Q`)、K 缓存(表示为 `K`)和 V 缓存(表示为 `V`)。我们当前的 GQA 推理对 `Q`、`K` 和 `V` 使用 BF16。

  • Q 是一个形状为 (`B`, `1`, `HQ`, `D`) 的 4D BF16 张量
  • K 是一个形状为 (`B`, `Tmax`, `HKV`, `D`) 的 4D BF16 张量
  • V 是一个形状为 (`B`, `Tmax`, `HKV`, `D`) 的 4D BF16 张量

其中

  • B 是批处理大小(输入提示的数量)
  • HQ 是查询头的数量
  • HKV 是 KV 头的数量(HQ 必须能被 HKV 整除)
  • Tmax 是最大上下文长度
  • D 是头维度(固定为 128)

GQA 简单地表示为 `bmm(softmax(bmm(Q, KT) / sqrt(D)), V)`。这会产生一个与 `Q` 具有相同形状的 4D BF16 张量(表示为 `O`)作为单个输出张量。请注意,矩阵乘法使用 BF16 执行,但是累加和 `softmax` 以 FP32 进行。我们将此称为“BF16 GQA”,因为 KV 缓存是 BF16。

Figure 1: The simplified workflow of BF16 GQA for LLM inference

图 1 LLM 推理中 BF16 GQA 的简化工作流程

INT4 GQA

为了进一步减小 KV 缓存的大小,我们探索了使用 INT4 代替 BF16 作为 KV 缓存的可能性。我们通过计算 INT4 GQA 的计算强度(CI)并将其与 BF16 GQA 的计算强度进行比较来估计潜在的性能改进,因为 CI 代表每字节的 FLOPS。我们计算 `QKT` 和 `PV` 的 CI(如方程 1 所示),因为它们将 KV 缓存作为操作数。请注意,我们忽略了 `Q` 加载,因为它与 KV 缓存相比可以忽略不计。我们还忽略了任何不在全局内存上的中间数据加载/存储。因此,CI 仅考虑计算 FLOPS 和 KV 缓存加载。

Equation 1

方程 (1)

假设 `HQ` = 8 且 `HKV` = 1,BF16 KV 缓存的 CI 为 8,而 INT4 KV 缓存的 CI 为 32。CI 表明 BF16 和 INT4 GQA 都是内存受限的(A100 和 H100 的 BF16 张量核心的峰值 CI 分别为 312 TF / 2 TB/s = 141990 TF / 3.35 TB/s = 269;请注意,这些 TF 数字不考虑稀疏性)。此外,使用 INT4 KV 缓存,我们应该期望性能比 BF16 GQA 提高高达 4 倍。

为了在 GQA 中启用 INT4 KV 缓存支持,我们可以在将 KV 缓存传递给 BF16 GQA 运算符之前,将其从 INT4 反量化为 BF16。然而,由于 KV 缓存通常很大,从/向全局内存复制可能代价高昂。此外,解码 GQA 是一种内存受限操作(内存单元比计算单元利用率高得多)。图 2 显示了 xFormers 中 FMHA CUTLASS BF16 GQA 内核的 NCU 配置文件,这是 GQA 最先进的实现之一。从图中可以明显看出,内存是瓶颈。

Figure 2: The NCU profile of the FMHA CUTLASS BF16 kernel in xFormers

图 2 xFormers 中 FMHA CUTLASS BF16 内核的 NCU 配置文件

更有效的替代方案是将 INT4 反量化与 GQA 操作融合(如图 3 所示)。换句话说,让 GQA 直接读取 INT4 KV 缓存并在内核中执行 INT4 到 BF16 的转换。这种改变可能会减少 KV 缓存所需的全局内存读取量,从而可能导致延迟降低。我们称之为“INT4 GQA”。

Figure 3: The workflow of fused INT4 GQA

图 3 融合 INT4 GQA 的工作流程

我们在下表中列出了最先进的 GQA 实现及其功能,如表 1 所示。

表 1 最先进的 GQA 实现

实现表示BF16 GQA融合 INT4 GQA
Flash-Decoding (Triton 实现)FD
Flash Attention (v2.3.3)FA
CUDA 基线CU

除了 CU 之外,所有实现都支持 split-K 和非 split-K。CU 只有 split-K 实现。只有 FA 在后端有启发式方法来确定运行 split-K 还是非 split-K 内核。对于其他实现,用户必须明确选择要运行的版本。在这份说明中,我们关注长上下文长度(在我们的实验中,我们使用 8192 的上下文长度),因此尽可能选择 split-K 版本。

作为基线,我们测量了 NVIDIA A100 和 H100 GPU 上最先进 GQA 实现的性能。延迟(微秒)和实现的带宽(GB/s)报告在表 2 中。请注意,我们运行了一系列 split-K(从 2 到 128 分割),并报告了每个实现的最佳性能。对于所有实验,我们使用 8192 的上下文长度。对于 INT4 GQA,我们使用行式量化(即量化组数 = 1)。

表 2 基线 GQA 性能

在 A100 上

时间 (us)BF16 GQAINT4 GQA
批次大小FDFACUFDFACU
32139133183137143
64245229335234257
128433555596432455
2568269771127815866
51216071670219415811659
有效带宽 (GB/s)BF16 GQAINT4 GQA
批次大小FDFACUFDFACU
329651012736262250
6410971175802305278
1281240968901331314
25613011100954351331
51213381287980362345

在 H100 上

时间 (us)BF16 GQAINT4 GQA
批次大小FDFACUFDFACU
3291901147096
64148146200113162
128271298361205294
256515499658389558
5121000101112607561066
有效带宽 (GB/s)BF16 GQAINT4 GQA
批次大小FDFACUFDFACU
32148114961178511371
64181518401345631443
128198218021487699487
256208721561634736513
512215021271706757537

首先,让我们讨论 BF16 GQA 的性能:在所有实现中,CU 的性能排名垫底。FD 和 FA 具有可比的性能。当批处理大小小于或等于 64 时,FA 使用 split-K 内核,性能略优于 FD。但是,当批处理大小大于 64 时,FD 表现更好。

同样的趋势也适用于 INT4 GQA。然而,我们没有测量 FA 的性能,因为它不支持 INT4 KV 缓存。FD 在所有情况下都优于 CU。

当比较 BF16 和 INT4 GQA 之间 FD 的延迟时,我们发现它们几乎相同。这表明 INT4 GQA 效率非常低,这可以通过 INT4 GQA 比 BF16 GQA 显著降低的可实现带宽进一步证实。在查看 CU 的性能时,同样的趋势也适用。

带有 Tensor Cores 的 CUDA INT4 GQA 实现

在本节中,我们简要介绍了我们的基线实现,即带有张量核心的 CUDA INT4 GQA(CU)。每个线程块仅处理一个 KV 头和来自一个输入提示的一组查询头。因此,每个线程块执行 `mm(softmax(mm(Q, KT) / sqrt(D)), V)`;请注意,执行的是 `mm` 而不是 `bmm`。此外,由于这是一个 split-K 实现,KV 缓存中的令牌在不同的线程块之间进行分割。请注意,每个线程块包含 4 个 warp(对于 NVIDIA A100 和 H100 GPU,每个 warp 包含 32 个线程)。每个线程块中的工作在 warp 之间进行分割。在每个 warp 内,我们使用 WMMA API 在张量核心上计算矩阵乘法。图 4 展示了 CU 中的工作划分。

Figure 4: CU work partitioning

图 4 CU 工作划分

优化带张量核心的 CUDA INT4 GQA 内核

在这份说明中,我们讨论了我们对带有张量核心的 INT4 GQA(CU)CUDA 实现所应用的优化。根据上一节中的 CI 分析,理想目标是将 INT4 GQA 性能提高 4 倍。请注意,当上下文长度较长时,查询大小与 KV 缓存大小相比可以忽略不计。

在我们的分析中,我们使用 NVIDIA Nsight Compute (NCU) 作为主要性能分析器。我们通用的瓶颈消除方法是最小化停顿周期。我们对 INT4 GQA 应用了 10 个优化,其中三个是针对 NVIDIA A100/H100 GPU 的。这些优化是众所周知的 CUDA 优化技术,可以推广到许多应用程序。

值得注意的是,我们选择优化 CUDA 实现而不是 Flash-Decoding 实现(基于 Triton)的原因是,使用 CUDA,我们可以更好地控制低级指令的生成方式。我们应用的许多优化技术,例如直接操作张量核心片段(优化 7-9),无法通过 Triton 完成,因为它不向开发人员公开低级细节。然而,这些优化可以集成到基于编译器的解决方案中,以使优化可用于更广泛的运算符,这确实是我们未来计划的一部分。

优化 1:展开 `K` 加载

问题分析

NCU 配置文件显示,在 `K` 加载期间,只有 2 个全局加载,随后在 `dequantize_permuted_int4` 处出现 内存停顿。内存停顿是长时间的记分板停顿,表示等待全局内存访问。这表明内核没有发出足够的内存加载

以隐藏全局加载延迟。内核发出数据加载,然后立即等待使用数据,导致全局加载延迟暴露。停顿如图 5 所示。

Figure 5: K loading before unrolling

图 5 展开前 K 加载(箭头指向的数字是全局内存等待导致的停顿周期)

解决方案

在基线实现中,我们使用 `uint32_t` 在单次加载中加载 8 个 INT4 `K` 值,并且在每次迭代中执行 2 次 `uint32_t` 加载,即 16 个 INT4 K 值。为了更好地隐藏全局加载延迟,我们在 `dequantize_permuted_int4` 中使用 `K` 值之前,发出 8 次 `uint32_t` 加载而不是两次。这允许编译器展开加载并重新排序指令以更好地隐藏全局加载延迟。图 6 显示了展开后 `K` 加载的 NCU 配置文件。比较图 5 和图 6,我们通过展开 `K` 加载有效地减少了停顿周期。

Figure 6: K loading after unrolling

图 6 展开后 K 加载(箭头指向的数字是全局内存等待导致的停顿周期)

结果

表 3 INT4 GQA 优化 1 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 1基准优化 1
321371431342622502671.021.07
642342572373052783020.991.09
1284324554223313143391.021.08
2568158668063513313551.011.07
5121581165915503623453691.021.07

优化 2:改进 `P` 类型转换 (FP32->BF16)

问题分析

由于 `softmax(bmm(Q, KT) / sqrt(D))` 的结果是 FP32(在图 3 中表示为 `P`),内核必须在将其馈送到下一个 `bmm` 计算之前将 `P` 从 FP32 转换为 BF16。内核通过将 FP32 数据从共享内存的一个位置复制到共享内存的另一个位置来执行 `P` 的 FP32 到 BF16 转换。这会导致共享内存访问期间的停顿(如图 7 所示),这可能是由 (1) 共享内存间接寻址;和 (2) 共享内存 bank 冲突引起的,因为每个线程访问一个 16 位元素(因此,两个线程可以同时访问同一个内存 bank)。

Figure 7: P type casting before Optimization 2

图 7 优化 2 之前的 `P` 类型转换(箭头指向的数字是共享内存等待导致的停顿周期)

解决方案

我们使用线程块中的所有线程进行就地类型转换。每个线程操作两个连续的元素,以避免在存储 BF16 时发生共享内存 bank 冲突。所有线程同时处理相同的头 (`h`),以保证转换的正确性。就地转换步骤如下:

  1. 每个线程将相同头部的 2 个 FP32 令牌元素从共享内存加载到寄存器中
  2. 调用 `__syncthreads()` 以确保每个线程完成数据读取
  3. 每个线程将其数据转换为 2 个 BF16 令牌元素,然后将结果存储到相同的共享内存中

我们对实现应用的一些优化包括:

  • 使用矢量类型(特别是 `nv_bfloat2`)
  • 展开数据加载/存储,即在调用 `__syncthreads()` 之前执行多次加载,并在 `__syncthreads()` 之后执行多次存储

此优化后,在 `P` 类型转换期间未观察到长时间停顿,如图 8 所示。

Figure 8: P type casting after Optimization 2

图 8 优化 2 后的 `P` 类型转换(箭头指向的数字是共享内存等待导致的停顿周期)

罪魁祸首

由于我们通过使用寄存器作为中间存储来展开数据加载/存储,每个线程的寄存器数量增加,导致占用率降低。

结果

表 4 INT4 GQA 优化 2 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 2基准优化 2
321371431262622502851.091.14
642342572213052783241.061.16
1284324553953313143621.091.15
2568158667493513313821.091.16
5121581165914353623453991.101.16

优化 3:移除 `QKT` 最大值计算的局部内存使用

问题分析

在 softmax 计算期间,内核必须为每个头计算最大 `QKT`。它使用一个临时“线程局部”存储来存储每个线程的最大 `QKT` 结果(每个头一个浮点值)。根据编译器,线程局部存储可以分配在寄存器(片上)或局部内存(片外 == 全局内存)上。不幸的是,在基线中,线程局部存储位于局部内存中,这比寄存器慢得多(如图 9 所示)。我们怀疑这是因为编译器无法在编译时确定线程局部存储的索引(因为内核中的头数 (`H`) 是一个运行时变量)。访问局部内存就像访问寄存器一样可能会损害内核的性能。

Figure 9: Local memory access during max QKT computation

图 9 `QKT` 最大值计算期间的局部内存访问

解决方案

我们意识到,每个线程不需要 `H`(头部数量)浮点数作为临时存储,因为每个线程只需计算一个头部的最大 `QKT`,而不是所有头部。因此,每个线程只需一个浮点数,可以轻松存储在寄存器中。为了累积 warp 之间的最大结果,我们使用共享内存。此优化消除了在最大 `QKT` 计算期间对局部内存的使用。

结果

表 5 INT4 GQA 优化 3 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 3基准优化 3
321371431192622503001.141.20
642342572063052783481.141.25
1284324553683313143891.171.24
2568158666963513314111.171.24
5121581165913383623454281.181.24

优化 4:移除行和的局部内存使用

问题分析

优化 3 类似,在 `softmax` 计算的行和计算过程中也观察到局部内存使用问题。由于局部内存位于芯片外部,将其作为访问寄存器一样访问可能会损害内核的性能。

解决方案:

我们对行和计算应用了与最大 `QKT` 计算相同的解决方案。即让每个线程只计算一个头的行和,这只需要每个线程一个浮点数。这消除了对局部内存的需求。

结果

表 6 INT4 GQA 优化 4 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 4基准优化 4
321371431182622503021.151.21
642342572043052783511.151.26
1284324553643313143931.191.25
2568158666883513314161.181.26
5121581165913283623454311.191.25

优化 5:为 `V` 加载添加预取

问题分析

当加载 `V` 时,观察到与 `K` 加载相同的问题。也就是说,内核发出数据加载,然后立即等待使用数据,导致全局加载延迟暴露。然而,当使用上述展开技术时,编译器将临时缓冲区分配到局部内存而不是寄存器,导致速度大大降低。

解决方案

我们对 `V` 加载采用数据预取技术。在消耗当前迭代值后,我们立即加载下一个迭代的 `V` 值。这使得数据加载可以与 `PK` 计算重叠,从而提高内核性能。

结果

表 7 INT4 GQA 优化 5 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 5基准优化 5
321371431092622503271.251.31
642342571943052783701.211.33
1284324553453313144141.251.32
2568158666493513314411.261.33
5121581165912443623454601.271.33

优化 6:添加分组 INT4 (组 = 4) 与矢量加载

问题分析

在此优化之前,CU 仅支持行式 INT4 量化。也就是说,每行中的每列共享相同的比例因子。每行的比例因子存储在每行的前 4 个字节中,如图 10 所示。在内核中,每个线程每次只加载一行。由于每行包含 68 个字节(4 个字节用于比例因子,64 个字节用于数据),因此无法保证每行都与任何矢量类型的大小对齐。因此,无法使用矢量加载来加载 KV 缓存。

Figure 10: The layout of each row of INT4 KV cache with row-wise quantization

图 10 采用行式量化的 INT4 KV 缓存每行的布局

解决方案

我们已经实现了对组式 INT4 量化的支持,组数为 4。在这种情况下,KV 缓存张量中每行的列被分成 4 个相等组。同一组内的列共享相同的量化/反量化比例因子。INT4 KV 缓存的数据布局如图 11 所示。所有组的比例因子都被序列化并存储在每行的开头。INT4 数据也进行序列化并放置在比例因子旁边。

由于每行的字节数现在变为 80 字节,我们可以使用矢量类型(在本例中为 `uint2`)来加载数据。(我们 使用 `uint4`,因为每个线程由于张量核心片段大小每次只加载 16 个 INT4。)矢量加载通常优于标量加载,因为它不会导致额外的字节加载。

Figure 11: The layout of each row of INT4 KV cache with row-wise quantization

图 11 采用行式量化的 INT4 KV 缓存每行的布局

结果

表 8 INT4 GQA 优化 6 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 6基准优化 6
321371431112622503221.231.29
642342571923052783721.221.34
1284324553463313144141.251.32
2568158666423513314461.271.35
5121581165912443623454601.271.33

表 9 INT4 GQA 优化 6 的性能(组式量化,组数 = 4)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUDA_WMMAFDCUDA_WMMA与 FD 相比
优化 6优化 6
321291163253641.31
642191953854311.36
1283923474294841.39
2567196384685271.41
512137512254895501.43

优化 7:直接从 WMMA 片段计算最大 `QKT`(A100/H100 专用)

问题分析

我们观察到在最大 `QKT` 计算期间由于共享内存访问导致的大量停顿(表现为大量的短记分板停顿),如图 12 所示。

Figure 12: Stalls due to shared memory access during max QKT computation

图 12 `QKT` 最大值计算期间由于共享内存访问导致的停顿(箭头指向的数字是共享内存等待导致的停顿周期)

解决方案

我们通过直接从 WMMA 片段(即张量核心片段)计算 `QKT` 的最大值来绕过共享内存。WMMA 片段的布局特定于 GPU 架构。在此优化中,我们仅为 NVIDIA A100/H100 GPU 启用了此优化。其他 GPU 仍将使用共享内存进行 `QKT` 最大值计算。通过绕过共享内存,我们有效地消除了由共享内存访问引起的停顿。用于存储 `QKT` 结果的 `C` 片段的张量核心布局如图 13 所示。

Figure 13: C fragment (QKT storage) tensor core layout on A100/H100

图 13 A100/H100 上 `C` 片段 (`QKT` 存储) 张量核心布局

表 10 INT4 GQA 优化 7 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 7基准优化 7
321371431072622503331.271.33
642342571833052783911.281.40
1284324553333313144301.301.37
2568158666203513314611.311.40
5121581165912063623454751.311.38

表 11 INT4 GQA 优化 7 的性能(组式量化,组数 = 4)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUDA_WMMAFDCUDA_WMMA与 FD 相比与 CUDA_WMMA 优化 6 相比
优化 6优化 7优化 6优化 7
321291161113253643801.171.04
642191951873854314491.171.04
1283923473334294845061.181.04
2567196386154685275471.171.04
5121375122511844895505691.161.03

优化 8:将 FP32->BF16 结果直接写入 `P` 片段(A100/H100 专用)

问题分析

在 `P` 片段的 FP32-BF16 转换过程中,内核从共享内存加载 FP32 数据,进行转换,然后将 BF16 数据存储回共享内存。此外,转换需要多次线程块同步 (`__syncthreads()`)。

解决方案

由于内核的数据分区设计,每个 warp 只对 `P` 片段执行一次遍历。因此,我们不必将转换结果写回共享内存以供将来使用。为了避免将 BF16 数据写入共享内存和线程块同步,我们让每个 warp 从共享内存加载 `P` WMMA 片段的 FP32 数据,进行转换,然后将 BF16 数据直接写入 `P` 片段。

请注意,此优化仅适用于 NVIDIA A100 和 H100 GPU,因为 WMMA 片段布局取决于架构。对于非 A100/H100 GPU,内核将回退到原始路径。

`P` 片段张量核心布局如图 14 所示。请注意,此布局特定于 NVIDIA A100/H100 GPU。

Figure 14: P fragment tensor core layout on A100/H100

图 14 A100/H100 上 `P` 片段张量核心布局

表 12 INT4 GQA 优化 8 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 8基准优化 8
321371431012622503531.351.41
642342571743052784101.341.47
1284324553173313144511.361.43
2568158665903513314851.381.47
5121581165911433623455011.381.45

表 13 INT4 GQA 优化 8 的性能(组式量化,组数 = 4)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUDA_WMMAFDCUDA_WMMA与 FD 相比与 CUDA_WMMA 优化 6 相比
优化 6优化 8优化 6优化 8
321291161063253643961.221.09
642191951803854314671.211.08
1283923473194294845281.231.09
2567196385964685275651.211.07
5121375122511384895505911.211.08

优化 9:交错 P 共享内存布局(A100/H100 专用)

问题分析

我们在 `P` 加载期间观察到大量的共享内存 bank 冲突。bank 冲突的数量取决于内存访问步长。例如,对于 split-Ks = 32 和最大序列长度 = 8192,我们观察到 32 个 bank 中只有 4 个并行访问(内存访问步长 = 256)。从图 14 中,当所有线程访问元素 0 时,具有相同 `threadIdx.x % 4` 的线程访问相同的 bank。

Figure 15: P fragment in shared memory before swizzling

图 15 交错前的共享内存中的 P 片段

解决方案

我们以一种避免 bank 冲突的方式对共享内存中 `P` 加载/存储的布局进行混洗。换句话说,我们使用混洗布局存储 `QKT` 结果(`C` 片段)并加载它们(`P` 片段)。此外,我们不使用取决于每个线程块的令牌数量的原始内存访问步长,而是使用片段的列大小作为步长,该步长是常量。因此,`P` 片段的加载和存储始终是连续的。

C 和 P 片段的新布局如图 16 所示。使用新布局,可以保证 16 个 bank 并行访问,如图 17 所示。

Figure 16: The swizzled layouts of C and P fragments

图 16 C 和 P 片段的交错布局

Figure 17: P fragment in shared memory after swizzling

图 17 交错后共享内存中的 P 片段

表 14 INT4 GQA 优化 9 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 9基准优化 9
32137143982622503651.391.46
642342571673052784291.411.54
1284324552993313144791.451.52
2568158665493513315211.481.58
5121581165910603623455401.491.56

表 15 INT4 GQA 优化 9 的性能(组式量化,组数 = 4)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUDA_WMMAFDCUDA_WMMA与 FD 相比与 CUDA_WMMA 优化 6 相比
优化 6优化 9优化 6优化 9
321291161053253644001.231.10
642191951743854314841.261.12
1283923473024294845581.301.15
2567196385604685276011.281.14
5121375122510654895506321.291.15

优化 10:填充共享内存以进行 INT4 反量化

问题分析

一旦内核从全局内存读取 INT4 `K` 或 `V` 缓存,它会执行反量化并将结果(BF16)存储在共享内存中。然后,BF16 数据从共享内存(通过 WMMA 接口)加载到 WMMA 片段。我们观察到 `K` 和 `V` 访问都存在大量的 bank 冲突。例如,对于 `K` 存储,32 个 bank 中只有 4 个并行访问。对于 `K` 加载,16 个 bank 并行访问。`V` 存储和加载也发生同样的情况。请参阅解决方案部分中的图。

解决方案

我们填充共享内存以减少 bank 冲突。具体来说,我们将每行填充 2。也就是说,`K` 的行步长变为 `F_K` + 2,`V` 的行步长变为 `F_N` + 2(`F_K` 和 `F_N` 分别是 `K` 和 `V` WMMA 片段的固定宽度)。通过此优化,我们能够将 bank 冲突减少 1.8 倍,如图 18 所示。

Figure 18: Bank conflicts before and after Optimization 10

图 18 优化 10 前后的 Bank 冲突

优化 10 之后,对于 `K` 存储,32 个 bank 并行访问(如图 19 所示),而对于 `K` 加载,29 个 bank 并行访问(如图 20 所示)。

Figure 19: K fragment store shared memory layout without and with padding

图 19 有填充和无填充的 K 片段存储共享内存布局

Figure 20: K fragment load shared memory layout without and with padding

图 20 有填充和无填充的 K 片段加载共享内存布局

表 16 INT4 GQA 优化 10 的性能(行式量化)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUFDCU与 FD 相比与 CU 基线相比
基准优化 10基准优化 10
32137143942622503801.451.52
642342571513052784751.551.71
1284324552663313145381.631.71
2568158664893513315861.671.77
512158116599303623456161.701.79

表 17 INT4 GQA 优化 10 的性能(组式量化,组数 = 4)

批次大小时间 (us)带宽 (GB/s)加速比
FDCUDA_WMMAFDCUDA_WMMA与 FD 相比与 CUDA_WMMA 优化 6 相比
优化 6优化 10优化 6优化 10
32129116993253644251.311.17
642191951613854315231.361.21
1283923472824294845981.391.23
2567196385094685276621.411.25
512137512259654895506981.431.27

性能评估

微基准测试结果

我们还使用我们优化的内核评估了 BF16 GQA 性能(如表 19 所示)。对于 BF16,CU 的性能通常仍不如 FD 和 FA。这是预期的,因为我们的优化主要针对 INT4。

虽然 INT4 GQA 仍不如 BF16 GQA 高效(参见实现的带宽),但值得注意的是,当比较 FD BF16 GQA 性能与 CU INT4 GQA 性能时,我们可以看到 INT4 的延迟小于 BF16

表 19 CU 优化后 BF16 GQA 和 INT GQA 的性能

在 A100 上

时间 (us)BF16 GQAINT4 GQA
批次大小FDFACU 之前CU 之后FDFACU 之前CU 之后
3213913318316313714394
64245229335276234257151
128433555596517432455266
2568269771127999815866489
512160716702194187915811659930
有效带宽 (GB/s)BF16 GQAINT4 GQA
批次大小FDFACU 之前CU 之后FDFACU 之前CU 之后
329651012736824262250380
6410971175802972305278475
12812409689011039331314538
256130111009541075351331586
512133812879801144362345616

在 H100 上

时间 (us)BF16 GQAINT4 GQA
批次大小FDFACU 之前CU 之后FDFACU 之前CU 之后
329190114100709664
64148146200183113162101
128271298361308205294170
256515499658556389558306
51210001011126010667561066575
有效带宽 (GB/s)BF16 GQAINT4 GQA
批次大小FDFACU 之前CU 之后FDFACU 之前CU 之后
321481149611781341511371560
641815184013451470631443710
1281982180214871743699487844
2562087215616341934736513935
5122150212717062015757537996

端到端结果

我们在 8 个 H100 GPU 上评估了 Llama 2 70B 中优化的 INT4 GQA 内核。我们运行了模型的端到端,但仅报告了解码延迟。我们使用 FP8 FFN(前馈网络)以强调解码阶段的注意力性能。我们将批处理大小从 1 更改为 256,上下文长度从 2,048 (2K) 更改为 16,384 (16K)。端到端性能结果如下图所示。

Figure 21: Meta Llama 2 decode latency (ms) comparison

图 21 Meta Llama 2 解码延迟 (ms) 比较(BF16 GQA 在大批处理大小配置中内存不足)

代码

如果您感兴趣,请 在此处 查看我们的代码。如果您有任何问题,请随时在 GitHub 上提出问题,我们将很乐意提供帮助。欢迎您的贡献!