低精度 KV 缓存的高效解码分组查询注意力
引言
生成式人工智能凭借其像人类一样生成内容的能力席卷了全球。许多生成式人工智能工具都由大型语言模型(LLM)驱动,例如 Meta 的 Llama 模型和 OpenAI 的 ChatGPT。LLM 面临的主要挑战之一是支持大“上下文长度”(也称为“序列长度”)。上下文长度是指模型用于理解输入上下文并生成响应的令牌数量。更长的上下文长度通常意味着响应具有更高的精度和质量。然而,长上下文长度在计算和内存方面都是密集的。这主要是由于以下原因:
- 注意力层的计算复杂度与上下文长度成比例增长(增长率取决于注意力算法)。因此,在使用长上下文长度时,注意力层可能会成为瓶颈,尤其是在预填充阶段,此时注意力是计算密集型的。
- KV 缓存大小与上下文长度线性增长,从而对内存需求施加更高的压力,进而减慢已经内存受限的注意力解码。此外,由于内存容量有限,当 KV 缓存变大时,批处理大小会减小,这通常会导致吞吐量下降。
与其他上述问题相比,计算复杂度的增长更难解决。解决 KV 缓存大小增长问题的一种方法是使用低精度 KV 缓存。根据我们的实验,在 Meta Llama 2 推理的解码阶段,分组 INT4 量化在准确性方面与 BF16 KV 缓存提供了可比的结果。然而,我们没有观察到任何延迟改进,尽管注意力解码层读取的数据减少了 4 倍。这意味着 INT4 注意力在利用宝贵的 HBM 带宽方面比 BF16 注意力效率低 4 倍。
在这份说明中,我们讨论了我们对 INT4 GQA(分组查询注意力——我们在 LLM 推理阶段使用的注意力层)应用的 CUDA 优化,以将其性能在 NVIDIA A100 GPU 上提高高达 1.8 倍,在 NVIDIA H100 GPU 上提高高达 1.9 倍。
- 优化的 CUDA INT4 GQA 比 INT4 Flash-Decoding GQA(我们在上述实验中使用的性能最佳的 INT4 GQA)在 A100 上性能提高了 1.4 倍至 1.7 倍,在 H100 上性能提高了 1.09 倍至 1.3 倍。
- 优化的 CUDA INT4 GQA 比 BF16 Flash-Decoding GQA 在 A100 上性能提高了 1.5 倍至 1.7 倍,在 H100 上性能提高了 1.4 倍至 1.7 倍。
背景
用于 LLM 推理的 GQA
分组查询注意力(GQA)是多头注意力(MHA)的一种变体,其中每个 KV 缓存头在查询头组之间共享。我们的 LLM 推理在预填充和解码阶段都采用 GQA 作为注意力层,以减少 KV 缓存的容量需求。我们在推理中使用多个 GPU,其中 KV 缓存和查询头分布在不同的 GPU 上。每个 GPU 运行一个带有一个 KV 头和一组 Q 头的注意力层。因此,从单个 GPU 的角度来看,GQA 组件也可以描述为 MQA(多查询注意力)。
解码 GQA 的简化工作流程如图 1 所示。GQA 接受三个主要输入:输入查询(表示为 `Q`)、K 缓存(表示为 `K`)和 V 缓存(表示为 `V`)。我们当前的 GQA 推理对 `Q`、`K` 和 `V` 使用 BF16。
Q是一个形状为 (`B`, `1`, `HQ`, `D`) 的 4D BF16 张量K是一个形状为 (`B`, `Tmax`, `HKV`, `D`) 的 4D BF16 张量V是一个形状为 (`B`, `Tmax`, `HKV`, `D`) 的 4D BF16 张量
其中
B是批处理大小(输入提示的数量)HQ是查询头的数量HKV是 KV 头的数量(HQ必须能被HKV整除)Tmax是最大上下文长度D是头维度(固定为 128)
GQA 简单地表示为 `bmm(softmax(bmm(Q, KT) / sqrt(D)), V)`。这会产生一个与 `Q` 具有相同形状的 4D BF16 张量(表示为 `O`)作为单个输出张量。请注意,矩阵乘法使用 BF16 执行,但是累加和 `softmax` 以 FP32 进行。我们将此称为“BF16 GQA”,因为 KV 缓存是 BF16。

图 1 LLM 推理中 BF16 GQA 的简化工作流程
INT4 GQA
为了进一步减小 KV 缓存的大小,我们探索了使用 INT4 代替 BF16 作为 KV 缓存的可能性。我们通过计算 INT4 GQA 的计算强度(CI)并将其与 BF16 GQA 的计算强度进行比较来估计潜在的性能改进,因为 CI 代表每字节的 FLOPS。我们计算 `QKT` 和 `PV` 的 CI(如方程 1 所示),因为它们将 KV 缓存作为操作数。请注意,我们忽略了 `Q` 加载,因为它与 KV 缓存相比可以忽略不计。我们还忽略了任何不在全局内存上的中间数据加载/存储。因此,CI 仅考虑计算 FLOPS 和 KV 缓存加载。

方程 (1)
假设 `HQ` = 8 且 `HKV` = 1,BF16 KV 缓存的 CI 为 8,而 INT4 KV 缓存的 CI 为 32。CI 表明 BF16 和 INT4 GQA 都是内存受限的(A100 和 H100 的 BF16 张量核心的峰值 CI 分别为 312 TF / 2 TB/s = 141 和 990 TF / 3.35 TB/s = 269;请注意,这些 TF 数字不考虑稀疏性)。此外,使用 INT4 KV 缓存,我们应该期望性能比 BF16 GQA 提高高达 4 倍。
为了在 GQA 中启用 INT4 KV 缓存支持,我们可以在将 KV 缓存传递给 BF16 GQA 运算符之前,将其从 INT4 反量化为 BF16。然而,由于 KV 缓存通常很大,从/向全局内存复制可能代价高昂。此外,解码 GQA 是一种内存受限操作(内存单元比计算单元利用率高得多)。图 2 显示了 xFormers 中 FMHA CUTLASS BF16 GQA 内核的 NCU 配置文件,这是 GQA 最先进的实现之一。从图中可以明显看出,内存是瓶颈。

图 2 xFormers 中 FMHA CUTLASS BF16 内核的 NCU 配置文件
更有效的替代方案是将 INT4 反量化与 GQA 操作融合(如图 3 所示)。换句话说,让 GQA 直接读取 INT4 KV 缓存并在内核中执行 INT4 到 BF16 的转换。这种改变可能会减少 KV 缓存所需的全局内存读取量,从而可能导致延迟降低。我们称之为“INT4 GQA”。

图 3 融合 INT4 GQA 的工作流程
我们在下表中列出了最先进的 GQA 实现及其功能,如表 1 所示。
表 1 最先进的 GQA 实现
| 实现 | 表示 | BF16 GQA | 融合 INT4 GQA |
| Flash-Decoding (Triton 实现) | FD | 是 | 是 |
| Flash Attention (v2.3.3) | FA | 是 | 否 |
| CUDA 基线 | CU | 是 | 是 |
除了 CU 之外,所有实现都支持 split-K 和非 split-K。CU 只有 split-K 实现。只有 FA 在后端有启发式方法来确定运行 split-K 还是非 split-K 内核。对于其他实现,用户必须明确选择要运行的版本。在这份说明中,我们关注长上下文长度(在我们的实验中,我们使用 8192 的上下文长度),因此尽可能选择 split-K 版本。
作为基线,我们测量了 NVIDIA A100 和 H100 GPU 上最先进 GQA 实现的性能。延迟(微秒)和实现的带宽(GB/s)报告在表 2 中。请注意,我们运行了一系列 split-K(从 2 到 128 分割),并报告了每个实现的最佳性能。对于所有实验,我们使用 8192 的上下文长度。对于 INT4 GQA,我们使用行式量化(即量化组数 = 1)。
表 2 基线 GQA 性能
在 A100 上
| 时间 (us) | BF16 GQA | INT4 GQA | ||||
| 批次大小 | FD | FA | CU | FD | FA | CU |
| 32 | 139 | 133 | 183 | 137 | – | 143 |
| 64 | 245 | 229 | 335 | 234 | – | 257 |
| 128 | 433 | 555 | 596 | 432 | – | 455 |
| 256 | 826 | 977 | 1127 | 815 | – | 866 |
| 512 | 1607 | 1670 | 2194 | 1581 | – | 1659 |
| 有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||
| 批次大小 | FD | FA | CU | FD | FA | CU |
| 32 | 965 | 1012 | 736 | 262 | – | 250 |
| 64 | 1097 | 1175 | 802 | 305 | – | 278 |
| 128 | 1240 | 968 | 901 | 331 | – | 314 |
| 256 | 1301 | 1100 | 954 | 351 | – | 331 |
| 512 | 1338 | 1287 | 980 | 362 | – | 345 |
在 H100 上
| 时间 (us) | BF16 GQA | INT4 GQA | ||||
| 批次大小 | FD | FA | CU | FD | FA | CU |
| 32 | 91 | 90 | 114 | 70 | – | 96 |
| 64 | 148 | 146 | 200 | 113 | – | 162 |
| 128 | 271 | 298 | 361 | 205 | – | 294 |
| 256 | 515 | 499 | 658 | 389 | – | 558 |
| 512 | 1000 | 1011 | 1260 | 756 | – | 1066 |
| 有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||
| 批次大小 | FD | FA | CU | FD | FA | CU |
| 32 | 1481 | 1496 | 1178 | 511 | – | 371 |
| 64 | 1815 | 1840 | 1345 | 631 | – | 443 |
| 128 | 1982 | 1802 | 1487 | 699 | – | 487 |
| 256 | 2087 | 2156 | 1634 | 736 | – | 513 |
| 512 | 2150 | 2127 | 1706 | 757 | – | 537 |
首先,让我们讨论 BF16 GQA 的性能:在所有实现中,CU 的性能排名垫底。FD 和 FA 具有可比的性能。当批处理大小小于或等于 64 时,FA 使用 split-K 内核,性能略优于 FD。但是,当批处理大小大于 64 时,FD 表现更好。
同样的趋势也适用于 INT4 GQA。然而,我们没有测量 FA 的性能,因为它不支持 INT4 KV 缓存。FD 在所有情况下都优于 CU。
当比较 BF16 和 INT4 GQA 之间 FD 的延迟时,我们发现它们几乎相同。这表明 INT4 GQA 效率非常低,这可以通过 INT4 GQA 比 BF16 GQA 显著降低的可实现带宽进一步证实。在查看 CU 的性能时,同样的趋势也适用。
带有 Tensor Cores 的 CUDA INT4 GQA 实现
在本节中,我们简要介绍了我们的基线实现,即带有张量核心的 CUDA INT4 GQA(CU)。每个线程块仅处理一个 KV 头和来自一个输入提示的一组查询头。因此,每个线程块执行 `mm(softmax(mm(Q, KT) / sqrt(D)), V)`;请注意,执行的是 `mm` 而不是 `bmm`。此外,由于这是一个 split-K 实现,KV 缓存中的令牌在不同的线程块之间进行分割。请注意,每个线程块包含 4 个 warp(对于 NVIDIA A100 和 H100 GPU,每个 warp 包含 32 个线程)。每个线程块中的工作在 warp 之间进行分割。在每个 warp 内,我们使用 WMMA API 在张量核心上计算矩阵乘法。图 4 展示了 CU 中的工作划分。

图 4 CU 工作划分
优化带张量核心的 CUDA INT4 GQA 内核
在这份说明中,我们讨论了我们对带有张量核心的 INT4 GQA(CU)CUDA 实现所应用的优化。根据上一节中的 CI 分析,理想目标是将 INT4 GQA 性能提高 4 倍。请注意,当上下文长度较长时,查询大小与 KV 缓存大小相比可以忽略不计。
在我们的分析中,我们使用 NVIDIA Nsight Compute (NCU) 作为主要性能分析器。我们通用的瓶颈消除方法是最小化停顿周期。我们对 INT4 GQA 应用了 10 个优化,其中三个是针对 NVIDIA A100/H100 GPU 的。这些优化是众所周知的 CUDA 优化技术,可以推广到许多应用程序。
值得注意的是,我们选择优化 CUDA 实现而不是 Flash-Decoding 实现(基于 Triton)的原因是,使用 CUDA,我们可以更好地控制低级指令的生成方式。我们应用的许多优化技术,例如直接操作张量核心片段(优化 7-9),无法通过 Triton 完成,因为它不向开发人员公开低级细节。然而,这些优化可以集成到基于编译器的解决方案中,以使优化可用于更广泛的运算符,这确实是我们未来计划的一部分。
优化 1:展开 `K` 加载
问题分析
NCU 配置文件显示,在 `K` 加载期间,只有 2 个全局加载,随后在 `dequantize_permuted_int4` 处出现 内存停顿。内存停顿是长时间的记分板停顿,表示等待全局内存访问。这表明内核没有发出足够的内存加载
以隐藏全局加载延迟。内核发出数据加载,然后立即等待使用数据,导致全局加载延迟暴露。停顿如图 5 所示。

图 5 展开前 K 加载(箭头指向的数字是全局内存等待导致的停顿周期)
解决方案
在基线实现中,我们使用 `uint32_t` 在单次加载中加载 8 个 INT4 `K` 值,并且在每次迭代中执行 2 次 `uint32_t` 加载,即 16 个 INT4 K 值。为了更好地隐藏全局加载延迟,我们在 `dequantize_permuted_int4` 中使用 `K` 值之前,发出 8 次 `uint32_t` 加载而不是两次。这允许编译器展开加载并重新排序指令以更好地隐藏全局加载延迟。图 6 显示了展开后 `K` 加载的 NCU 配置文件。比较图 5 和图 6,我们通过展开 `K` 加载有效地减少了停顿周期。

图 6 展开后 K 加载(箭头指向的数字是全局内存等待导致的停顿周期)
结果
表 3 INT4 GQA 优化 1 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 1 | 基准 | 优化 1 | |||||
| 32 | 137 | 143 | 134 | 262 | 250 | 267 | 1.02 | 1.07 |
| 64 | 234 | 257 | 237 | 305 | 278 | 302 | 0.99 | 1.09 |
| 128 | 432 | 455 | 422 | 331 | 314 | 339 | 1.02 | 1.08 |
| 256 | 815 | 866 | 806 | 351 | 331 | 355 | 1.01 | 1.07 |
| 512 | 1581 | 1659 | 1550 | 362 | 345 | 369 | 1.02 | 1.07 |
优化 2:改进 `P` 类型转换 (FP32->BF16)
问题分析
由于 `softmax(bmm(Q, KT) / sqrt(D))` 的结果是 FP32(在图 3 中表示为 `P`),内核必须在将其馈送到下一个 `bmm` 计算之前将 `P` 从 FP32 转换为 BF16。内核通过将 FP32 数据从共享内存的一个位置复制到共享内存的另一个位置来执行 `P` 的 FP32 到 BF16 转换。这会导致共享内存访问期间的停顿(如图 7 所示),这可能是由 (1) 共享内存间接寻址;和 (2) 共享内存 bank 冲突引起的,因为每个线程访问一个 16 位元素(因此,两个线程可以同时访问同一个内存 bank)。

图 7 优化 2 之前的 `P` 类型转换(箭头指向的数字是共享内存等待导致的停顿周期)
解决方案
我们使用线程块中的所有线程进行就地类型转换。每个线程操作两个连续的元素,以避免在存储 BF16 时发生共享内存 bank 冲突。所有线程同时处理相同的头 (`h`),以保证转换的正确性。就地转换步骤如下:
- 每个线程将相同头部的 2 个 FP32 令牌元素从共享内存加载到寄存器中
- 调用 `__syncthreads()` 以确保每个线程完成数据读取
- 每个线程将其数据转换为 2 个 BF16 令牌元素,然后将结果存储到相同的共享内存中
我们对实现应用的一些优化包括:
- 使用矢量类型(特别是 `nv_bfloat2`)
- 展开数据加载/存储,即在调用 `__syncthreads()` 之前执行多次加载,并在 `__syncthreads()` 之后执行多次存储
此优化后,在 `P` 类型转换期间未观察到长时间停顿,如图 8 所示。

图 8 优化 2 后的 `P` 类型转换(箭头指向的数字是共享内存等待导致的停顿周期)
罪魁祸首
由于我们通过使用寄存器作为中间存储来展开数据加载/存储,每个线程的寄存器数量增加,导致占用率降低。
结果
表 4 INT4 GQA 优化 2 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 2 | 基准 | 优化 2 | |||||
| 32 | 137 | 143 | 126 | 262 | 250 | 285 | 1.09 | 1.14 |
| 64 | 234 | 257 | 221 | 305 | 278 | 324 | 1.06 | 1.16 |
| 128 | 432 | 455 | 395 | 331 | 314 | 362 | 1.09 | 1.15 |
| 256 | 815 | 866 | 749 | 351 | 331 | 382 | 1.09 | 1.16 |
| 512 | 1581 | 1659 | 1435 | 362 | 345 | 399 | 1.10 | 1.16 |
优化 3:移除 `QKT` 最大值计算的局部内存使用
问题分析
在 softmax 计算期间,内核必须为每个头计算最大 `QKT`。它使用一个临时“线程局部”存储来存储每个线程的最大 `QKT` 结果(每个头一个浮点值)。根据编译器,线程局部存储可以分配在寄存器(片上)或局部内存(片外 == 全局内存)上。不幸的是,在基线中,线程局部存储位于局部内存中,这比寄存器慢得多(如图 9 所示)。我们怀疑这是因为编译器无法在编译时确定线程局部存储的索引(因为内核中的头数 (`H`) 是一个运行时变量)。访问局部内存就像访问寄存器一样可能会损害内核的性能。

图 9 `QKT` 最大值计算期间的局部内存访问
解决方案
我们意识到,每个线程不需要 `H`(头部数量)浮点数作为临时存储,因为每个线程只需计算一个头部的最大 `QKT`,而不是所有头部。因此,每个线程只需一个浮点数,可以轻松存储在寄存器中。为了累积 warp 之间的最大结果,我们使用共享内存。此优化消除了在最大 `QKT` 计算期间对局部内存的使用。
结果
表 5 INT4 GQA 优化 3 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 3 | 基准 | 优化 3 | |||||
| 32 | 137 | 143 | 119 | 262 | 250 | 300 | 1.14 | 1.20 |
| 64 | 234 | 257 | 206 | 305 | 278 | 348 | 1.14 | 1.25 |
| 128 | 432 | 455 | 368 | 331 | 314 | 389 | 1.17 | 1.24 |
| 256 | 815 | 866 | 696 | 351 | 331 | 411 | 1.17 | 1.24 |
| 512 | 1581 | 1659 | 1338 | 362 | 345 | 428 | 1.18 | 1.24 |
优化 4:移除行和的局部内存使用
问题分析
与 优化 3 类似,在 `softmax` 计算的行和计算过程中也观察到局部内存使用问题。由于局部内存位于芯片外部,将其作为访问寄存器一样访问可能会损害内核的性能。
解决方案:
我们对行和计算应用了与最大 `QKT` 计算相同的解决方案。即让每个线程只计算一个头的行和,这只需要每个线程一个浮点数。这消除了对局部内存的需求。
结果
表 6 INT4 GQA 优化 4 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 4 | 基准 | 优化 4 | |||||
| 32 | 137 | 143 | 118 | 262 | 250 | 302 | 1.15 | 1.21 |
| 64 | 234 | 257 | 204 | 305 | 278 | 351 | 1.15 | 1.26 |
| 128 | 432 | 455 | 364 | 331 | 314 | 393 | 1.19 | 1.25 |
| 256 | 815 | 866 | 688 | 351 | 331 | 416 | 1.18 | 1.26 |
| 512 | 1581 | 1659 | 1328 | 362 | 345 | 431 | 1.19 | 1.25 |
优化 5:为 `V` 加载添加预取
问题分析
当加载 `V` 时,观察到与 `K` 加载相同的问题。也就是说,内核发出数据加载,然后立即等待使用数据,导致全局加载延迟暴露。然而,当使用上述展开技术时,编译器将临时缓冲区分配到局部内存而不是寄存器,导致速度大大降低。
解决方案
我们对 `V` 加载采用数据预取技术。在消耗当前迭代值后,我们立即加载下一个迭代的 `V` 值。这使得数据加载可以与 `PK` 计算重叠,从而提高内核性能。
结果
表 7 INT4 GQA 优化 5 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 5 | 基准 | 优化 5 | |||||
| 32 | 137 | 143 | 109 | 262 | 250 | 327 | 1.25 | 1.31 |
| 64 | 234 | 257 | 194 | 305 | 278 | 370 | 1.21 | 1.33 |
| 128 | 432 | 455 | 345 | 331 | 314 | 414 | 1.25 | 1.32 |
| 256 | 815 | 866 | 649 | 351 | 331 | 441 | 1.26 | 1.33 |
| 512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |
优化 6:添加分组 INT4 (组 = 4) 与矢量加载
问题分析
在此优化之前,CU 仅支持行式 INT4 量化。也就是说,每行中的每列共享相同的比例因子。每行的比例因子存储在每行的前 4 个字节中,如图 10 所示。在内核中,每个线程每次只加载一行。由于每行包含 68 个字节(4 个字节用于比例因子,64 个字节用于数据),因此无法保证每行都与任何矢量类型的大小对齐。因此,无法使用矢量加载来加载 KV 缓存。

图 10 采用行式量化的 INT4 KV 缓存每行的布局
解决方案
我们已经实现了对组式 INT4 量化的支持,组数为 4。在这种情况下,KV 缓存张量中每行的列被分成 4 个相等组。同一组内的列共享相同的量化/反量化比例因子。INT4 KV 缓存的数据布局如图 11 所示。所有组的比例因子都被序列化并存储在每行的开头。INT4 数据也进行序列化并放置在比例因子旁边。
由于每行的字节数现在变为 80 字节,我们可以使用矢量类型(在本例中为 `uint2`)来加载数据。(我们 不 使用 `uint4`,因为每个线程由于张量核心片段大小每次只加载 16 个 INT4。)矢量加载通常优于标量加载,因为它不会导致额外的字节加载。

图 11 采用行式量化的 INT4 KV 缓存每行的布局
结果
表 8 INT4 GQA 优化 6 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 6 | 基准 | 优化 6 | |||||
| 32 | 137 | 143 | 111 | 262 | 250 | 322 | 1.23 | 1.29 |
| 64 | 234 | 257 | 192 | 305 | 278 | 372 | 1.22 | 1.34 |
| 128 | 432 | 455 | 346 | 331 | 314 | 414 | 1.25 | 1.32 |
| 256 | 815 | 866 | 642 | 351 | 331 | 446 | 1.27 | 1.35 |
| 512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |
表 9 INT4 GQA 优化 6 的性能(组式量化,组数 = 4)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | ||
| FD | CUDA_WMMA | FD | CUDA_WMMA | 与 FD 相比 | |
| 优化 6 | 优化 6 | ||||
| 32 | 129 | 116 | 325 | 364 | 1.31 |
| 64 | 219 | 195 | 385 | 431 | 1.36 |
| 128 | 392 | 347 | 429 | 484 | 1.39 |
| 256 | 719 | 638 | 468 | 527 | 1.41 |
| 512 | 1375 | 1225 | 489 | 550 | 1.43 |
优化 7:直接从 WMMA 片段计算最大 `QKT`(A100/H100 专用)
问题分析
我们观察到在最大 `QKT` 计算期间由于共享内存访问导致的大量停顿(表现为大量的短记分板停顿),如图 12 所示。

图 12 `QKT` 最大值计算期间由于共享内存访问导致的停顿(箭头指向的数字是共享内存等待导致的停顿周期)
解决方案
我们通过直接从 WMMA 片段(即张量核心片段)计算 `QKT` 的最大值来绕过共享内存。WMMA 片段的布局特定于 GPU 架构。在此优化中,我们仅为 NVIDIA A100/H100 GPU 启用了此优化。其他 GPU 仍将使用共享内存进行 `QKT` 最大值计算。通过绕过共享内存,我们有效地消除了由共享内存访问引起的停顿。用于存储 `QKT` 结果的 `C` 片段的张量核心布局如图 13 所示。

图 13 A100/H100 上 `C` 片段 (`QKT` 存储) 张量核心布局
表 10 INT4 GQA 优化 7 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 7 | 基准 | 优化 7 | |||||
| 32 | 137 | 143 | 107 | 262 | 250 | 333 | 1.27 | 1.33 |
| 64 | 234 | 257 | 183 | 305 | 278 | 391 | 1.28 | 1.40 |
| 128 | 432 | 455 | 333 | 331 | 314 | 430 | 1.30 | 1.37 |
| 256 | 815 | 866 | 620 | 351 | 331 | 461 | 1.31 | 1.40 |
| 512 | 1581 | 1659 | 1206 | 362 | 345 | 475 | 1.31 | 1.38 |
表 11 INT4 GQA 优化 7 的性能(组式量化,组数 = 4)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CUDA_WMMA | FD | CUDA_WMMA | 与 FD 相比 | 与 CUDA_WMMA 优化 6 相比 | |||
| 优化 6 | 优化 7 | 优化 6 | 优化 7 | |||||
| 32 | 129 | 116 | 111 | 325 | 364 | 380 | 1.17 | 1.04 |
| 64 | 219 | 195 | 187 | 385 | 431 | 449 | 1.17 | 1.04 |
| 128 | 392 | 347 | 333 | 429 | 484 | 506 | 1.18 | 1.04 |
| 256 | 719 | 638 | 615 | 468 | 527 | 547 | 1.17 | 1.04 |
| 512 | 1375 | 1225 | 1184 | 489 | 550 | 569 | 1.16 | 1.03 |
优化 8:将 FP32->BF16 结果直接写入 `P` 片段(A100/H100 专用)
问题分析
在 `P` 片段的 FP32-BF16 转换过程中,内核从共享内存加载 FP32 数据,进行转换,然后将 BF16 数据存储回共享内存。此外,转换需要多次线程块同步 (`__syncthreads()`)。
解决方案
由于内核的数据分区设计,每个 warp 只对 `P` 片段执行一次遍历。因此,我们不必将转换结果写回共享内存以供将来使用。为了避免将 BF16 数据写入共享内存和线程块同步,我们让每个 warp 从共享内存加载 `P` WMMA 片段的 FP32 数据,进行转换,然后将 BF16 数据直接写入 `P` 片段。
请注意,此优化仅适用于 NVIDIA A100 和 H100 GPU,因为 WMMA 片段布局取决于架构。对于非 A100/H100 GPU,内核将回退到原始路径。
`P` 片段张量核心布局如图 14 所示。请注意,此布局特定于 NVIDIA A100/H100 GPU。

图 14 A100/H100 上 `P` 片段张量核心布局
表 12 INT4 GQA 优化 8 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 8 | 基准 | 优化 8 | |||||
| 32 | 137 | 143 | 101 | 262 | 250 | 353 | 1.35 | 1.41 |
| 64 | 234 | 257 | 174 | 305 | 278 | 410 | 1.34 | 1.47 |
| 128 | 432 | 455 | 317 | 331 | 314 | 451 | 1.36 | 1.43 |
| 256 | 815 | 866 | 590 | 351 | 331 | 485 | 1.38 | 1.47 |
| 512 | 1581 | 1659 | 1143 | 362 | 345 | 501 | 1.38 | 1.45 |
表 13 INT4 GQA 优化 8 的性能(组式量化,组数 = 4)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CUDA_WMMA | FD | CUDA_WMMA | 与 FD 相比 | 与 CUDA_WMMA 优化 6 相比 | |||
| 优化 6 | 优化 8 | 优化 6 | 优化 8 | |||||
| 32 | 129 | 116 | 106 | 325 | 364 | 396 | 1.22 | 1.09 |
| 64 | 219 | 195 | 180 | 385 | 431 | 467 | 1.21 | 1.08 |
| 128 | 392 | 347 | 319 | 429 | 484 | 528 | 1.23 | 1.09 |
| 256 | 719 | 638 | 596 | 468 | 527 | 565 | 1.21 | 1.07 |
| 512 | 1375 | 1225 | 1138 | 489 | 550 | 591 | 1.21 | 1.08 |
优化 9:交错 P 共享内存布局(A100/H100 专用)
问题分析
我们在 `P` 加载期间观察到大量的共享内存 bank 冲突。bank 冲突的数量取决于内存访问步长。例如,对于 split-Ks = 32 和最大序列长度 = 8192,我们观察到 32 个 bank 中只有 4 个并行访问(内存访问步长 = 256)。从图 14 中,当所有线程访问元素 0 时,具有相同 `threadIdx.x % 4` 的线程访问相同的 bank。

图 15 交错前的共享内存中的 P 片段
解决方案
我们以一种避免 bank 冲突的方式对共享内存中 `P` 加载/存储的布局进行混洗。换句话说,我们使用混洗布局存储 `QKT` 结果(`C` 片段)并加载它们(`P` 片段)。此外,我们不使用取决于每个线程块的令牌数量的原始内存访问步长,而是使用片段的列大小作为步长,该步长是常量。因此,`P` 片段的加载和存储始终是连续的。
C 和 P 片段的新布局如图 16 所示。使用新布局,可以保证 16 个 bank 并行访问,如图 17 所示。

图 16 C 和 P 片段的交错布局

图 17 交错后共享内存中的 P 片段
表 14 INT4 GQA 优化 9 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 9 | 基准 | 优化 9 | |||||
| 32 | 137 | 143 | 98 | 262 | 250 | 365 | 1.39 | 1.46 |
| 64 | 234 | 257 | 167 | 305 | 278 | 429 | 1.41 | 1.54 |
| 128 | 432 | 455 | 299 | 331 | 314 | 479 | 1.45 | 1.52 |
| 256 | 815 | 866 | 549 | 351 | 331 | 521 | 1.48 | 1.58 |
| 512 | 1581 | 1659 | 1060 | 362 | 345 | 540 | 1.49 | 1.56 |
表 15 INT4 GQA 优化 9 的性能(组式量化,组数 = 4)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CUDA_WMMA | FD | CUDA_WMMA | 与 FD 相比 | 与 CUDA_WMMA 优化 6 相比 | |||
| 优化 6 | 优化 9 | 优化 6 | 优化 9 | |||||
| 32 | 129 | 116 | 105 | 325 | 364 | 400 | 1.23 | 1.10 |
| 64 | 219 | 195 | 174 | 385 | 431 | 484 | 1.26 | 1.12 |
| 128 | 392 | 347 | 302 | 429 | 484 | 558 | 1.30 | 1.15 |
| 256 | 719 | 638 | 560 | 468 | 527 | 601 | 1.28 | 1.14 |
| 512 | 1375 | 1225 | 1065 | 489 | 550 | 632 | 1.29 | 1.15 |
优化 10:填充共享内存以进行 INT4 反量化
问题分析
一旦内核从全局内存读取 INT4 `K` 或 `V` 缓存,它会执行反量化并将结果(BF16)存储在共享内存中。然后,BF16 数据从共享内存(通过 WMMA 接口)加载到 WMMA 片段。我们观察到 `K` 和 `V` 访问都存在大量的 bank 冲突。例如,对于 `K` 存储,32 个 bank 中只有 4 个并行访问。对于 `K` 加载,16 个 bank 并行访问。`V` 存储和加载也发生同样的情况。请参阅解决方案部分中的图。
解决方案
我们填充共享内存以减少 bank 冲突。具体来说,我们将每行填充 2。也就是说,`K` 的行步长变为 `F_K` + 2,`V` 的行步长变为 `F_N` + 2(`F_K` 和 `F_N` 分别是 `K` 和 `V` WMMA 片段的固定宽度)。通过此优化,我们能够将 bank 冲突减少 1.8 倍,如图 18 所示。

图 18 优化 10 前后的 Bank 冲突
优化 10 之后,对于 `K` 存储,32 个 bank 并行访问(如图 19 所示),而对于 `K` 加载,29 个 bank 并行访问(如图 20 所示)。

图 19 有填充和无填充的 K 片段存储共享内存布局

图 20 有填充和无填充的 K 片段加载共享内存布局
表 16 INT4 GQA 优化 10 的性能(行式量化)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CU | FD | CU | 与 FD 相比 | 与 CU 基线相比 | |||
| 基准 | 优化 10 | 基准 | 优化 10 | |||||
| 32 | 137 | 143 | 94 | 262 | 250 | 380 | 1.45 | 1.52 |
| 64 | 234 | 257 | 151 | 305 | 278 | 475 | 1.55 | 1.71 |
| 128 | 432 | 455 | 266 | 331 | 314 | 538 | 1.63 | 1.71 |
| 256 | 815 | 866 | 489 | 351 | 331 | 586 | 1.67 | 1.77 |
| 512 | 1581 | 1659 | 930 | 362 | 345 | 616 | 1.70 | 1.79 |
表 17 INT4 GQA 优化 10 的性能(组式量化,组数 = 4)
| 批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
| FD | CUDA_WMMA | FD | CUDA_WMMA | 与 FD 相比 | 与 CUDA_WMMA 优化 6 相比 | |||
| 优化 6 | 优化 10 | 优化 6 | 优化 10 | |||||
| 32 | 129 | 116 | 99 | 325 | 364 | 425 | 1.31 | 1.17 |
| 64 | 219 | 195 | 161 | 385 | 431 | 523 | 1.36 | 1.21 |
| 128 | 392 | 347 | 282 | 429 | 484 | 598 | 1.39 | 1.23 |
| 256 | 719 | 638 | 509 | 468 | 527 | 662 | 1.41 | 1.25 |
| 512 | 1375 | 1225 | 965 | 489 | 550 | 698 | 1.43 | 1.27 |
性能评估
微基准测试结果
我们还使用我们优化的内核评估了 BF16 GQA 性能(如表 19 所示)。对于 BF16,CU 的性能通常仍不如 FD 和 FA。这是预期的,因为我们的优化主要针对 INT4。
虽然 INT4 GQA 仍不如 BF16 GQA 高效(参见实现的带宽),但值得注意的是,当比较 FD BF16 GQA 性能与 CU INT4 GQA 性能时,我们可以看到 INT4 的延迟小于 BF16。
表 19 CU 优化后 BF16 GQA 和 INT GQA 的性能
在 A100 上
| 时间 (us) | BF16 GQA | INT4 GQA | ||||||
| 批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
| 32 | 139 | 133 | 183 | 163 | 137 | – | 143 | 94 |
| 64 | 245 | 229 | 335 | 276 | 234 | – | 257 | 151 |
| 128 | 433 | 555 | 596 | 517 | 432 | – | 455 | 266 |
| 256 | 826 | 977 | 1127 | 999 | 815 | – | 866 | 489 |
| 512 | 1607 | 1670 | 2194 | 1879 | 1581 | – | 1659 | 930 |
| 有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||||
| 批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
| 32 | 965 | 1012 | 736 | 824 | 262 | – | 250 | 380 |
| 64 | 1097 | 1175 | 802 | 972 | 305 | – | 278 | 475 |
| 128 | 1240 | 968 | 901 | 1039 | 331 | – | 314 | 538 |
| 256 | 1301 | 1100 | 954 | 1075 | 351 | – | 331 | 586 |
| 512 | 1338 | 1287 | 980 | 1144 | 362 | – | 345 | 616 |
在 H100 上
| 时间 (us) | BF16 GQA | INT4 GQA | ||||||
| 批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
| 32 | 91 | 90 | 114 | 100 | 70 | – | 96 | 64 |
| 64 | 148 | 146 | 200 | 183 | 113 | – | 162 | 101 |
| 128 | 271 | 298 | 361 | 308 | 205 | – | 294 | 170 |
| 256 | 515 | 499 | 658 | 556 | 389 | – | 558 | 306 |
| 512 | 1000 | 1011 | 1260 | 1066 | 756 | – | 1066 | 575 |
| 有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||||
| 批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
| 32 | 1481 | 1496 | 1178 | 1341 | 511 | – | 371 | 560 |
| 64 | 1815 | 1840 | 1345 | 1470 | 631 | – | 443 | 710 |
| 128 | 1982 | 1802 | 1487 | 1743 | 699 | – | 487 | 844 |
| 256 | 2087 | 2156 | 1634 | 1934 | 736 | – | 513 | 935 |
| 512 | 2150 | 2127 | 1706 | 2015 | 757 | – | 537 | 996 |
端到端结果
我们在 8 个 H100 GPU 上评估了 Llama 2 70B 中优化的 INT4 GQA 内核。我们运行了模型的端到端,但仅报告了解码延迟。我们使用 FP8 FFN(前馈网络)以强调解码阶段的注意力性能。我们将批处理大小从 1 更改为 256,上下文长度从 2,048 (2K) 更改为 16,384 (16K)。端到端性能结果如下图所示。

图 21 Meta Llama 2 解码延迟 (ms) 比较(BF16 GQA 在大批处理大小配置中内存不足)
代码
如果您感兴趣,请 在此处 查看我们的代码。如果您有任何问题,请随时在 GitHub 上提出问题,我们将很乐意提供帮助。欢迎您的贡献!