采用低精度 KV 缓存的高效解码分组查询注意力
引言
生成式 AI 以其生成类人内容的能力风靡全球。许多此类生成式 AI 工具都由大型语言模型 (LLM) 提供支持,例如 Meta Llama 模型和 OpenAI 的 ChatGPT。LLM 的主要挑战之一是支持大“上下文长度”(也称为“序列长度”)。上下文长度是指模型用于理解输入上下文并生成响应的令牌数量。更长的上下文长度通常意味着更高的响应精度和质量。然而,长上下文长度计算和内存密集。这主要是由于以下原因:
- 注意力层的计算复杂度与上下文长度成比例增长(增长率取决于注意力算法)。因此,在使用长上下文长度时,注意力层可能成为瓶颈,尤其是在注意力受计算限制的预填充阶段。
- KV 缓存大小随上下文长度线性增长,从而对内存需求施加更大的压力,进而减慢已经受内存限制的注意力解码。此外,由于内存容量有限,当 KV 缓存变大时,批处理大小会减小,这通常会导致吞吐量下降。
与上述其他问题相比,计算复杂度增长更难解决。解决 KV 缓存大小增长问题的一种方法是使用低精度 KV 缓存。从我们的实验来看,与 Meta Llama 2 推理解码阶段的 BF16 KV 缓存相比,分组式 INT4 量化在精度方面提供了可比的结果。然而,尽管在注意力解码层读取的数据量减少了 4 倍,但我们并未观察到任何延迟改进。这意味着 INT4 注意力在利用宝贵的 HBM 带宽方面的效率比 BF16 注意力低 4 倍。
在本说明中,我们讨论了我们应用于 INT4 GQA(分组查询注意力——我们在 LLM 推理阶段使用的注意力层)的 CUDA 优化,以将其性能提高多达 <强>在 NVIDIA A100 GPU 上达到 1.8 倍强>,在 <强>NVIDIA H100 GPU 上达到 1.9 倍强>。
- <强>优化后的 CUDA INT4 GQA强> 在 A100 上比 INT4 Flash-Decoding GQA(我们在上述实验中使用的性能最佳的 INT4 GQA)快 <强>1.4 倍到 1.7 倍强>,在 H100 上快 <强>1.09 倍到 1.3 倍强>。
- <强>优化后的 CUDA INT4 GQA强> 在 A100 上比 <强>BF16 Flash-Decoding GQA强> 快 <强>1.5 倍到 1.7 倍强>,在 H100 上快 <强>1.4 倍到 1.7 倍强>。
背景
用于 LLM 推理的 GQA
分组查询注意力 (GQA) 是多头注意力 (MHA) 的一种变体,其中每个 KV 缓存头在查询头组之间共享。我们的 LLM 推理在预填充和解码阶段都采用 GQA 作为注意力层,以减少 KV 缓存的容量需求。我们在推理中使用多个 GPU,其中 KV 缓存和查询头分布在不同的 GPU 上。每个 GPU 运行一个注意力层,带有一个 KV 头和一组 Q 头。因此,从单个 GPU 的角度来看,GQA 组件也可以描述为 MQA (多查询注意力)。
图 1 描绘了解码 GQA 的简化工作流。GQA 接受三个主要输入:输入查询(表示为 Q
)、K 缓存(表示为 K
)和 V 缓存(表示为 V
)。我们当前的 GQA 推理使用 BF16 用于 Q
、K
和 V
。
Q
是一个形状为 (B
,1
,HQ
,D
) 的 4D BF16 张量K
是一个形状为 (B
,Tmax
,HKV
,D
) 的 4D BF16 张量V
是一个形状为 (B
,Tmax
,HKV
,D
) 的 4D BF16 张量
其中
B
是批处理大小(输入提示的数量)HQ
是查询头的数量HKV
是 KV 头的数量(HQ
必须能被HKV
整除)Tmax
是最大上下文长度D
是头维度(固定为 128)
GQA 只是 bmm(softmax(bmm(Q, KT) / sqrt(D)), V)
。这会产生一个输出张量(表示为 O
),它是一个与 Q
形状相同的 4D BF16 张量。请注意,矩阵乘法使用 BF16 执行,但是累加和 softmax
在 FP32 中执行。我们将此称为“BF16 GQA”,因为 KV 缓存是 BF16。

<强>图 1强> LLM 推理的 BF16 GQA 简化工作流
INT4 GQA
为了进一步减小 KV 缓存的大小,我们探索了使用 INT4 代替 BF16 作为 KV 缓存的可能性。我们通过计算 INT4 GQA 的计算强度 (CI) 并将其与 BF16 GQA 的计算强度进行比较来估计潜在的性能改进,因为 CI 表示每字节的浮点运算次数。我们计算 QKT
和 PV
的 CI(如公式 1 所示),因为它们将 KV 缓存作为操作数。请注意,我们忽略了 Q
加载,因为它与 KV 缓存相比可以忽略不计。我们还忽略了不在全局内存上的任何中间数据加载/存储。因此,CI 仅考虑计算 FLOPS 和 KV 缓存加载。

公式 (1)
假设 HQ
= 8 且 HKV
= 1,BF16 KV 缓存的 CI 为 8,而 INT4 KV 缓存的 CI 为 32。CI 表明 BF16 和 INT4 GQA 都是内存受限的(A100 和 H100 上 BF16 Tensor Cores 的峰值 CI 分别为 312 TF / 2 TB/s = 141 和 990 TF / 3.35 TB/s = 269;请注意,这些 TF 数字不包括稀疏性)。此外,使用 INT4 KV 缓存,我们预计性能将比 BF16 GQA 提高多达 4 倍。
为了在 GQA 中启用 INT4 KV 缓存支持,我们可以在将 KV 缓存传递给 BF16 GQA 运算符之前将其从 INT4 反量化为 BF16。然而,由于 KV 缓存通常很大,从/向全局内存复制可能代价高昂。此外,解码 GQA 是一个内存受限操作(内存单元的利用率远高于计算单元)。图 2 显示了 xFormers 中 FMHA CUTLASS BF16 GQA 内核 的 NCU 配置文件,这是 GQA 最先进的实现之一。从图中可以明显看出,内存是一个瓶颈。

<强>图 2强> xFormers 中 FMHA CUTLASS BF16 内核 的 NCU 配置文件
一种更有效的替代方法是将 INT4 反量化与 GQA 操作融合(如图 3 所示)。换句话说,让 GQA 直接读取 INT4 KV 缓存并在内核中执行 INT4 到 BF16 的转换。这种更改可能会减少 KV 缓存所需的全局内存读取量,从而可能导致延迟降低。我们称之为“INT4 GQA”。

<强>图 3强> 融合 INT4 GQA 的工作流
我们在下表中列出了最先进的 GQA 实现及其功能,如表 1 所示。
<强>表 1强> 最先进的 GQA 实现
实现 | 表示 | BF16 GQA | 融合 INT4 GQA |
Flash-Decoding(Triton 实现) | FD | 是 | 是 |
Flash Attention (v2.3.3) | FA | 是 | 否 |
CUDA 基线 | CU | 是 | 是 |
除了 CU,所有实现都支持 split-K 和非 split-K。CU 只有 split-K 实现。只有 FA 在后端有启发式算法来确定是运行 split-K 还是非 split-K 内核。对于其他实现,用户必须明确选择要运行的版本。在本说明中,我们关注长上下文长度(在我们的实验中,我们使用 8192 的上下文长度),因此尽可能选择 split-K 版本。
作为基准,我们测量了 NVIDIA A100 和 H100 GPU 上最先进的 GQA 实现的性能。表 2 报告了延迟(以微秒为单位)和实现的带宽(GB/s)。请注意,我们运行了一系列 split-K(从 2 到 128 个 split)并报告了每个实现的最佳性能。对于所有实验,我们使用 8192 的上下文长度。对于 INT4 GQA,我们使用行式量化(即量化组数 = 1)。
<强>表 2强> GQA 基线性能
在 A100 上
时间 (us) | BF16 GQA | INT4 GQA | ||||
批次大小 | FD | FA | CU | FD | FA | CU |
32 | 139 | 133 | 183 | 137 | – | 143 |
64 | 245 | 229 | 335 | 234 | – | 257 |
128 | 433 | 555 | 596 | 432 | – | 455 |
256 | 826 | 977 | 1127 | 815 | – | 866 |
512 | 1607 | 1670 | 2194 | 1581 | – | 1659 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||
批次大小 | FD | FA | CU | FD | FA | CU |
32 | 965 | 1012 | 736 | 262 | – | 250 |
64 | 1097 | 1175 | 802 | 305 | – | 278 |
128 | 1240 | 968 | 901 | 331 | – | 314 |
256 | 1301 | 1100 | 954 | 351 | – | 331 |
512 | 1338 | 1287 | 980 | 362 | – | 345 |
在 H100 上
时间 (us) | BF16 GQA | INT4 GQA | ||||
批次大小 | FD | FA | CU | FD | FA | CU |
32 | 91 | 90 | 114 | 70 | – | 96 |
64 | 148 | 146 | 200 | 113 | – | 162 |
128 | 271 | 298 | 361 | 205 | – | 294 |
256 | 515 | 499 | 658 | 389 | – | 558 |
512 | 1000 | 1011 | 1260 | 756 | – | 1066 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||
批次大小 | FD | FA | CU | FD | FA | CU |
32 | 1481 | 1496 | 1178 | 511 | – | 371 |
64 | 1815 | 1840 | 1345 | 631 | – | 443 |
128 | 1982 | 1802 | 1487 | 699 | – | 487 |
256 | 2087 | 2156 | 1634 | 736 | – | 513 |
512 | 2150 | 2127 | 1706 | 757 | – | 537 |
首先,让我们讨论一下 BF16 GQA 性能:在所有实现中,CU 的性能排名垫底。FD 和 FA 具有可比的性能。当批处理大小小于或等于 64 时,FA 使用 split-K 内核,并且性能略优于 FD。但是,当批处理大小大于 64 时,FD 的性能更好。
INT4 GQA 也是如此。然而,我们没有测量 FA 的性能,因为它不支持 INT4 KV 缓存。FD 在所有情况下都优于 CU。
当比较 BF16 和 INT4 GQA 之间 FD 的延迟时,我们发现它们几乎相同。这表明 <强>INT4 GQA 效率极低强>,这可以通过 INT4 GQA 比 BF16 GQA 显著降低的可用带宽进一步证实。CU 的性能也存在相同的趋势。
采用 Tensor Cores 的 INT4 GQA CUDA 实现
在本节中,我们简要描述了我们的基线实现,即带有 Tensor Cores 的 INT4 GQA CUDA (CU)。每个线程块只处理一个 KV 头和来自一个输入提示的一组查询头。因此,每个线程块执行 mm(softmax(mm(Q, KT) / sqrt(D)), V)
;请注意,这里执行的是 mm
而不是 bmm
。此外,由于这是一个 split-K 实现,KV 缓存中的令牌在不同的线程块之间拆分。请注意,每个线程块包含 4 个 warp(每个 warp 包含 32 个线程,适用于 NVIDIA A100 和 H100 GPU)。每个线程块中的工作在 warp 之间拆分。在每个 warp 中,我们使用 WMMA API 计算 Tensor Cores 上的矩阵乘法。图 4 展示了 CU 中的工作分区。

<强>图 4强> CU 工作分区
优化 INT4 GQA 的 CUDA Tensor Cores 内核
在本说明中,我们讨论了我们应用于带有 Tensor Cores 的 INT4 GQA (CU) CUDA 实现的优化。理想目标是根据上一节中的 CI 分析将 INT4 GQA 性能提高 4 倍。请注意,当上下文长度很长时,查询大小与 KV 缓存大小相比可以忽略不计。
在我们的分析中,我们使用 NVIDIA Nsight Compute (NCU) 作为主要分析器。我们一般的瓶颈消除方法是最小化停顿周期。我们对 INT4 GQA 应用了 10 项优化,其中三项是针对 NVIDIA A100/H100 GPU 的。这些优化是众所周知的 CUDA 优化技术,可以推广到许多应用程序。
值得注意的是,我们选择优化 CUDA 实现而不是 Flash-Decoding 实现(基于 Triton)的原因是,使用 CUDA,我们可以更好地控制低级指令的生成方式。我们应用的许多优化技术,例如直接在 Tensor Core 片段上操作(优化 7-9),无法通过 Triton 完成,因为它不向开发人员公开低级细节。然而,这些优化可以集成到基于编译器的解决方案中,以使这些优化可用于更广泛的运算符,这确实是我们未来计划的一部分。
优化 1:展开 K
加载
问题分析
NCU 分析显示,在 K
加载期间,只有 2 次全局加载,随后在 dequantize_permuted_int4
处出现<强>内存停顿强>。内存停顿是长时间的记分牌停顿,表明等待全局内存访问。这表明内核发出的内存加载不足
无法隐藏全局加载延迟。内核发出数据加载,然后立即等待使用数据,导致全局加载延迟暴露。停顿如图 5 所示。

<强>图 5强> 展开前的 K 加载(箭头指向的数字是全局内存等待造成的停顿周期)
解决方案
在基线实现中,我们使用 uint32_t
在单次加载中加载 8 个 INT4 K
值,并且在每次迭代中执行 2 次 uint32_t
加载,即 16 个 INT4 K 值。为了更好地隐藏全局加载延迟,我们在消耗 dequantize_permuted_int4
中的 K
值之前发出 8 次 uint32_t
加载而不是两次。这允许编译器展开加载并重新排序指令以更好地隐藏全局加载延迟。图 6 显示了展开后 K
加载的 NCU 配置文件。比较图 5 和图 6,我们通过展开 K
加载有效地减少了停顿周期。

<强>图 6强> 展开后的 K 加载(箭头指向的数字是全局内存等待造成的停顿周期)
结果
<强>表 3强> INT4 GQA 优化 1 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 1 | 基准 | 优化 1 | |||||
32 | 137 | 143 | 134 | 262 | 250 | 267 | 1.02 | 1.07 |
64 | 234 | 257 | 237 | 305 | 278 | 302 | 0.99 | 1.09 |
128 | 432 | 455 | 422 | 331 | 314 | 339 | 1.02 | 1.08 |
256 | 815 | 866 | 806 | 351 | 331 | 355 | 1.01 | 1.07 |
512 | 1581 | 1659 | 1550 | 362 | 345 | 369 | 1.02 | 1.07 |
优化 2:改进 P
类型转换 (FP32->BF16)
问题分析
由于 softmax(bmm(Q, KT) / sqrt(D))
的乘积是 FP32(在图 3 中表示为 P
),因此内核必须将 P
从 FP32 转换为 BF16,然后才能将其馈送到下一个 bmm
计算。内核通过将 FP32 数据从共享内存中的一个位置复制到共享内存中的另一个位置来执行 P
的 FP32 到 BF16 转换。这会在共享内存访问期间导致停顿(如图 7 所示),这可能是由 (1) 共享内存间接寻址;和 (2) 共享内存 bank 冲突引起的,因为每个线程访问一个 16 位元素(因此,两个线程可以同时访问同一个内存 bank)。

<强>图 7强> 优化 2 之前的 P
类型转换(箭头指向的数字是共享内存等待造成的停顿周期)
解决方案
我们使用线程块中的所有线程进行就地类型转换。每个线程操作两个连续的元素以避免存储 BF16 时共享内存 bank 冲突。所有线程同时处理相同的头 (h
) 以保证转换的正确性。就地转换步骤如下:
- 每个线程将来自共享内存的同一头中的 2 个 FP32 令牌元素加载到寄存器中
- 调用
__syncthreads()
以确保每个线程完成数据读取 - 每个线程将其数据转换为 2 个 BF16 令牌元素,然后将结果存储到相同的共享内存中
我们应用于实现的一些优化:
- 使用向量类型(尤其是
nv_bfloat2
) - 展开数据加载/存储,即在调用
__syncthreads()
之前执行多次加载,并在__syncthreads()
之后执行多次存储
经过此优化后,在 P
类型转换期间未观察到长时间停顿,如图 8 所示。

<强>图 8强> 优化 2 后的 P
类型转换(箭头指向的数字是共享内存等待造成的停顿周期)
问题
由于我们通过使用寄存器作为中间存储来展开数据加载/存储,因此每个线程的寄存器数量增加,导致占用率降低。
结果
<强>表 4强> INT4 GQA 优化 2 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 2 | 基准 | 优化 2 | |||||
32 | 137 | 143 | 126 | 262 | 250 | 285 | 1.09 | 1.14 |
64 | 234 | 257 | 221 | 305 | 278 | 324 | 1.06 | 1.16 |
128 | 432 | 455 | 395 | 331 | 314 | 362 | 1.09 | 1.15 |
256 | 815 | 866 | 749 | 351 | 331 | 382 | 1.09 | 1.16 |
512 | 1581 | 1659 | 1435 | 362 | 345 | 399 | 1.10 | 1.16 |
优化 3:移除 max QKT
计算的局部内存使用
问题分析
在 softmax 计算期间,内核必须为每个头计算 max QKT
。它使用一个临时“线程局部”存储来存储每个线程的 max QKT
结果(每个头一个浮点值)。根据编译器,线程局部存储可以分配在寄存器(片上)或局部内存(片外 == 全局内存)上。不幸的是,在基线中,线程局部存储驻留在局部内存中,这比寄存器慢得多(如图 9 所示)。我们怀疑这是因为编译器无法在编译时确定线程局部存储的索引(因为内核中的头数 (H
) 是一个运行时变量)。访问局部内存就像访问寄存器一样会损害内核的性能。

<强>图 9强> max QKT
计算期间的局部内存访问
解决方案
我们意识到每个线程不需要 H
(头数)浮点数作为临时存储,因为每个线程可以只计算一个头的 max QKT
,而不是所有头。因此,我们每个线程只需要一个浮点数,可以轻松存储在寄存器中。为了在 warp 之间累积最大结果,我们使用共享内存。此优化消除了 max QKT
计算期间的局部内存使用。
结果
<强>表 5强> INT4 GQA 优化 3 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 3 | 基准 | 优化 3 | |||||
32 | 137 | 143 | 119 | 262 | 250 | 300 | 1.14 | 1.20 |
64 | 234 | 257 | 206 | 305 | 278 | 348 | 1.14 | 1.25 |
128 | 432 | 455 | 368 | 331 | 314 | 389 | 1.17 | 1.24 |
256 | 815 | 866 | 696 | 351 | 331 | 411 | 1.17 | 1.24 |
512 | 1581 | 1659 | 1338 | 362 | 345 | 428 | 1.18 | 1.24 |
优化 4:移除行和的局部内存使用
问题分析
与 优化 3 类似,在 softmax
计算的行和计算过程中也观察到局部内存使用问题。由于局部内存位于芯片外,像访问寄存器一样访问它会损害内核的性能。
解决方案:
我们对行和计算应用了与 max QKT
计算相同的解决方案。也就是说,让每个线程只计算一个头的行和,这只需要每个线程一个浮点数。这消除了对局部内存的需求。
结果
<强>表 6强> INT4 GQA 优化 4 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 4 | 基准 | 优化 4 | |||||
32 | 137 | 143 | 118 | 262 | 250 | 302 | 1.15 | 1.21 |
64 | 234 | 257 | 204 | 305 | 278 | 351 | 1.15 | 1.26 |
128 | 432 | 455 | 364 | 331 | 314 | 393 | 1.19 | 1.25 |
256 | 815 | 866 | 688 | 351 | 331 | 416 | 1.18 | 1.26 |
512 | 1581 | 1659 | 1328 | 362 | 345 | 431 | 1.19 | 1.25 |
优化 5:添加 V
加载的预取
问题分析
在加载 V
时观察到与 K
加载相同的问题。也就是说,内核发出数据加载,然后立即等待使用数据,导致全局加载延迟暴露。然而,当使用上述展开技术时,编译器将临时缓冲区分配到局部内存而不是寄存器,导致速度大大降低。
解决方案
我们采用数据预取技术进行 V
加载。我们在当前迭代值消耗后立即加载下一迭代的 V
值。这允许数据加载与 PK
计算重叠,从而带来更好的内核性能。
结果
<强>表 7强> INT4 GQA 优化 5 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 5 | 基准 | 优化 5 | |||||
32 | 137 | 143 | 109 | 262 | 250 | 327 | 1.25 | 1.31 |
64 | 234 | 257 | 194 | 305 | 278 | 370 | 1.21 | 1.33 |
128 | 432 | 455 | 345 | 331 | 314 | 414 | 1.25 | 1.32 |
256 | 815 | 866 | 649 | 351 | 331 | 441 | 1.26 | 1.33 |
512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |
优化 6:添加分组式 INT4(分组数 = 4)与向量加载
问题分析
在此优化之前,CU 仅支持行式 INT4 量化。也就是说,每行中的每个列都共享相同的比例。每行的比例存储在每行的前 4 个字节中,如图 10 所示。在内核中,每个线程一次只加载一行。由于每行包含 68 字节(4 字节用于比例,64 字节用于数据),因此不能保证每行与任何向量类型的大小对齐。因此,不能使用向量加载来加载 KV 缓存。

<强>图 10强> 采用行式量化的 INT4 KV 缓存每行的布局
解决方案
我们已经实现了对分组式 INT4 量化的支持,分组数为 4。在这种情况下,KV 缓存张量中每行中的列被分成 4 个相等的分组。同一分组中的列共享相同的量化/反量化比例。INT4 KV 缓存的数据布局如图 11 所示。所有分组的比例都序列化并存储在每行的开头。INT4 数据也序列化并放置在比例旁边。
由于每行的字节数现在变为 80 字节,我们可以使用向量类型(在我们的案例中是 uint2
)来加载数据。(我们<强>不强>使用 uint4
,因为由于张量核心片段大小,每个线程一次只加载 16 个 INT4。)向量加载通常优于标量加载,因为它不会导致额外的字节加载。

<强>图 11强> 采用行式量化的 INT4 KV 缓存每行的布局
结果
<强>表 8强> INT4 GQA 优化 6 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 6 | 基准 | 优化 6 | |||||
32 | 137 | 143 | 111 | 262 | 250 | 322 | 1.23 | 1.29 |
64 | 234 | 257 | 192 | 305 | 278 | 372 | 1.22 | 1.34 |
128 | 432 | 455 | 346 | 331 | 314 | 414 | 1.25 | 1.32 |
256 | 815 | 866 | 642 | 351 | 331 | 446 | 1.27 | 1.35 |
512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |
<强>表 9强> INT4 GQA 优化 6 的性能(分组式量化,分组数 = 4)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | ||
FD | CUDA_WMMA | FD | CUDA_WMMA | 对比 FD | |
优化 6 | 优化 6 | ||||
32 | 129 | 116 | 325 | 364 | 1.31 |
64 | 219 | 195 | 385 | 431 | 1.36 |
128 | 392 | 347 | 429 | 484 | 1.39 |
256 | 719 | 638 | 468 | 527 | 1.41 |
512 | 1375 | 1225 | 489 | 550 | 1.43 |
优化 7:直接从 WMMA 片段计算 max QKT
(A100/H100 特定)
问题分析
我们观察到在 max QKT
计算期间由于共享内存访问而导致的长时间停顿(显示为长的短记分牌停顿),如图 12 所示。

<强>图 12强> max QKT
计算期间由于共享内存访问而导致的停顿(箭头指向的数字是共享内存等待造成的停顿周期)
解决方案
我们通过直接从 WMMA 片段(即张量核心片段)计算 max QKT
来绕过共享内存。WMMA 片段的布局特定于 GPU 架构。在此优化中,我们仅为 NVIDIA A100/H100 GPU 启用了此优化。其他 GPU 仍将使用共享内存进行 max QKT
计算。通过绕过共享内存,我们有效地消除了由共享内存访问引起的停顿。用于存储 QKT
结果的 C
片段的张量核心布局如图 13 所示。

<强>图 13强> A100/H100 上的 C
片段 (QKT
存储) 张量核心布局
<强>表 10强> INT4 GQA 优化 7 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 7 | 基准 | 优化 7 | |||||
32 | 137 | 143 | 107 | 262 | 250 | 333 | 1.27 | 1.33 |
64 | 234 | 257 | 183 | 305 | 278 | 391 | 1.28 | 1.40 |
128 | 432 | 455 | 333 | 331 | 314 | 430 | 1.30 | 1.37 |
256 | 815 | 866 | 620 | 351 | 331 | 461 | 1.31 | 1.40 |
512 | 1581 | 1659 | 1206 | 362 | 345 | 475 | 1.31 | 1.38 |
<强>表 11强> INT4 GQA 优化 7 的性能(分组式量化,分组数 = 4)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 对比 FD | 对比 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 7 | 优化 6 | 优化 7 | |||||
32 | 129 | 116 | 111 | 325 | 364 | 380 | 1.17 | 1.04 |
64 | 219 | 195 | 187 | 385 | 431 | 449 | 1.17 | 1.04 |
128 | 392 | 347 | 333 | 429 | 484 | 506 | 1.18 | 1.04 |
256 | 719 | 638 | 615 | 468 | 527 | 547 | 1.17 | 1.04 |
512 | 1375 | 1225 | 1184 | 489 | 550 | 569 | 1.16 | 1.03 |
优化 8:将 FP32->BF16 结果直接写入 P
片段(A100/H100 特定)
问题分析
在 P
片段的 FP32-BF16 转换过程中,内核从共享内存加载 FP32 数据,进行转换,然后将 BF16 数据存储回共享内存。此外,转换需要许多线程块同步 (__syncthreads()
)。
解决方案
由于内核的数据分区设计,每个 warp 只进行一次 P
片段的遍历。因此,我们不必将转换结果写回共享内存以供将来使用。为了避免将 BF16 数据写入共享内存和线程块同步,我们让每个 warp 从共享内存加载 P
WMMA 片段的 FP32 数据,进行转换,然后直接将 BF16 数据写入 P
片段。
请注意,此优化仅适用于 NVIDIA A100 和 H100 GPU,因为 WMMA 片段布局取决于架构。对于非 A100/H100 GPU,内核将回退到原始路径。
P
片段张量核心布局如图 14 所示。请注意,此布局特定于 NVIDIA A100/H100 GPU。

<强>图 14强> A100/H100 上的 P
片段张量核心布局
<强>表 12强> INT4 GQA 优化 8 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 8 | 基准 | 优化 8 | |||||
32 | 137 | 143 | 101 | 262 | 250 | 353 | 1.35 | 1.41 |
64 | 234 | 257 | 174 | 305 | 278 | 410 | 1.34 | 1.47 |
128 | 432 | 455 | 317 | 331 | 314 | 451 | 1.36 | 1.43 |
256 | 815 | 866 | 590 | 351 | 331 | 485 | 1.38 | 1.47 |
512 | 1581 | 1659 | 1143 | 362 | 345 | 501 | 1.38 | 1.45 |
<强>表 13强> INT4 GQA 优化 8 的性能(分组式量化,分组数 = 4)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 对比 FD | 对比 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 8 | 优化 6 | 优化 8 | |||||
32 | 129 | 116 | 106 | 325 | 364 | 396 | 1.22 | 1.09 |
64 | 219 | 195 | 180 | 385 | 431 | 467 | 1.21 | 1.08 |
128 | 392 | 347 | 319 | 429 | 484 | 528 | 1.23 | 1.09 |
256 | 719 | 638 | 596 | 468 | 527 | 565 | 1.21 | 1.07 |
512 | 1375 | 1225 | 1138 | 489 | 550 | 591 | 1.21 | 1.08 |
优化 9:混排 P 共享内存布局(A100/H100 特定)
问题分析
我们观察到在 P
加载期间存在大量的共享内存 bank 冲突。bank 冲突的数量取决于内存访问步幅。例如,对于 split-Ks = 32 和最大序列长度 = 8192,我们观察到 32 个 bank 中只有 4 个并行访问(内存访问步幅 = 256)。从图 14 中可以看出,当所有线程访问元素 0 时,具有相同 threadIdx.x % 4
的线程会访问相同的 bank。

<强>图 15强> 混排前共享内存中的 P 片段
解决方案
我们对共享内存中 P
加载/存储的布局进行混排,以避免 bank 冲突。换句话说,我们使用混排布局存储 QKT
结果(C
片段)并加载它们(P
片段)。此外,我们不使用依赖于每个线程块令牌数量的原始内存访问步幅,而是使用常量片段的列大小作为步幅。因此,P
片段的加载和存储始终是连续的。
C 和 P 片段的新布局如图 16 所示。通过新布局,可以保证 16 个 bank 并行访问,如图 17 所示。

<强>图 16强> C 和 P 片段的混排布局

<强>图 17强> 混排后共享内存中的 P 片段
<强>表 14强> INT4 GQA 优化 9 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 9 | 基准 | 优化 9 | |||||
32 | 137 | 143 | 98 | 262 | 250 | 365 | 1.39 | 1.46 |
64 | 234 | 257 | 167 | 305 | 278 | 429 | 1.41 | 1.54 |
128 | 432 | 455 | 299 | 331 | 314 | 479 | 1.45 | 1.52 |
256 | 815 | 866 | 549 | 351 | 331 | 521 | 1.48 | 1.58 |
512 | 1581 | 1659 | 1060 | 362 | 345 | 540 | 1.49 | 1.56 |
<强>表 15强> INT4 GQA 优化 9 的性能(分组式量化,分组数 = 4)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 对比 FD | 对比 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 9 | 优化 6 | 优化 9 | |||||
32 | 129 | 116 | 105 | 325 | 364 | 400 | 1.23 | 1.10 |
64 | 219 | 195 | 174 | 385 | 431 | 484 | 1.26 | 1.12 |
128 | 392 | 347 | 302 | 429 | 484 | 558 | 1.30 | 1.15 |
256 | 719 | 638 | 560 | 468 | 527 | 601 | 1.28 | 1.14 |
512 | 1375 | 1225 | 1065 | 489 | 550 | 632 | 1.29 | 1.15 |
优化 10:为 INT4 反量化填充共享内存
问题分析
一旦内核从全局内存读取 INT4 K
或 V
缓存,它会执行反量化并将结果(BF16)存储在共享内存中。然后,BF16 数据从共享内存加载到 WMMA 片段(通过 WMMA 接口)。我们观察到 K
和 V
访问都存在大量的 bank 冲突。例如,对于 K
存储,32 个 bank 中只有 4 个并行访问。对于 K
加载,16 个 bank 并行访问。V
存储和加载也发生同样的情况。请参见解决方案部分中的图。
解决方案
我们填充共享内存以减少 bank 冲突。具体来说,我们为每行填充 2。也就是说,K
的行步幅变为 F_K
+ 2,V 的行步幅变为 F_N
+ 2(F_K
和 F_N
分别是 K
和 V
WMMA 片段的固定宽度)。通过此优化,我们将 bank 冲突减少了 1.8 倍,如图 18 所示。

<强>图 18强> 优化 10 前后的 Bank 冲突
优化 10 后,对于 K
存储,32 个 bank 并行访问(如图 19 所示),而对于 K
加载,29 个 bank 并行访问(如图 20 所示)。

<强>图 19强> K 片段存储共享内存布局,不带填充和带填充

<强>图 20强> K 片段加载共享内存布局,不带填充和带填充
<强>表 16强> INT4 GQA 优化 10 的性能(行式量化)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 对比 FD | 对比 CU 基线 | |||
基准 | 优化 10 | 基准 | 优化 10 | |||||
32 | 137 | 143 | 94 | 262 | 250 | 380 | 1.45 | 1.52 |
64 | 234 | 257 | 151 | 305 | 278 | 475 | 1.55 | 1.71 |
128 | 432 | 455 | 266 | 331 | 314 | 538 | 1.63 | 1.71 |
256 | 815 | 866 | 489 | 351 | 331 | 586 | 1.67 | 1.77 |
512 | 1581 | 1659 | 930 | 362 | 345 | 616 | 1.70 | 1.79 |
<强>表 17强> INT4 GQA 优化 10 的性能(分组式量化,分组数 = 4)
批次大小 | 时间 (us) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 对比 FD | 对比 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 10 | 优化 6 | 优化 10 | |||||
32 | 129 | 116 | 99 | 325 | 364 | 425 | 1.31 | 1.17 |
64 | 219 | 195 | 161 | 385 | 431 | 523 | 1.36 | 1.21 |
128 | 392 | 347 | 282 | 429 | 484 | 598 | 1.39 | 1.23 |
256 | 719 | 638 | 509 | 468 | 527 | 662 | 1.41 | 1.25 |
512 | 1375 | 1225 | 965 | 489 | 550 | 698 | 1.43 | 1.27 |
性能评估
微基准测试结果
我们还使用我们优化后的内核评估了 BF16 GQA 性能(如表 19 所示)。对于 BF16,CU 的性能通常仍比 FD 和 FA 差。这是预期的,因为我们的优化是针对 INT4 的。
虽然 INT4 GQA 的效率仍不如 BF16 GQA(参见实现的带宽),但值得注意的是,当比较 FD BF16 GQA 性能与 CU INT4 GQA 性能时,<强>我们可以看到 INT4 的延迟小于 BF16 的延迟强>。
<强>表 19强> CU 优化后 BF16 GQA 和 INT GQA 的性能
在 A100 上
时间 (us) | BF16 GQA | INT4 GQA | ||||||
批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
32 | 139 | 133 | 183 | 163 | 137 | – | 143 | 94 |
64 | 245 | 229 | 335 | 276 | 234 | – | 257 | 151 |
128 | 433 | 555 | 596 | 517 | 432 | – | 455 | 266 |
256 | 826 | 977 | 1127 | 999 | 815 | – | 866 | 489 |
512 | 1607 | 1670 | 2194 | 1879 | 1581 | – | 1659 | 930 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||||
批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
32 | 965 | 1012 | 736 | 824 | 262 | – | 250 | 380 |
64 | 1097 | 1175 | 802 | 972 | 305 | – | 278 | 475 |
128 | 1240 | 968 | 901 | 1039 | 331 | – | 314 | 538 |
256 | 1301 | 1100 | 954 | 1075 | 351 | – | 331 | 586 |
512 | 1338 | 1287 | 980 | 1144 | 362 | – | 345 | 616 |
在 H100 上
时间 (us) | BF16 GQA | INT4 GQA | ||||||
批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
32 | 91 | 90 | 114 | 100 | 70 | – | 96 | 64 |
64 | 148 | 146 | 200 | 183 | 113 | – | 162 | 101 |
128 | 271 | 298 | 361 | 308 | 205 | – | 294 | 170 |
256 | 515 | 499 | 658 | 556 | 389 | – | 558 | 306 |
512 | 1000 | 1011 | 1260 | 1066 | 756 | – | 1066 | 575 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||||
批次大小 | FD | FA | CU 之前 | CU 之后 | FD | FA | CU 之前 | CU 之后 |
32 | 1481 | 1496 | 1178 | 1341 | 511 | – | 371 | 560 |
64 | 1815 | 1840 | 1345 | 1470 | 631 | – | 443 | 710 |
128 | 1982 | 1802 | 1487 | 1743 | 699 | – | 487 | 844 |
256 | 2087 | 2156 | 1634 | 1934 | 736 | – | 513 | 935 |
512 | 2150 | 2127 | 1706 | 2015 | 757 | – | 537 | 996 |
端到端结果
我们在 8 个 H100 GPU 上评估了 Llama 2 70B 中我们优化后的 INT4 GQA 内核。我们运行了端到端模型,但仅报告了解码延迟。我们使用 FP8 FFN(前馈网络)来强调解码阶段的注意力性能。我们将批处理大小从 1 更改为 256,将上下文长度从 2,048 (2K) 更改为 16,384 (16K)。端到端性能结果如下图所示。

<强>图 21强> Meta Llama 2 解码延迟 (ms) 比较(BF16 GQA 在大批处理大小配置中内存不足)
代码
如果您有兴趣,请在此处查看我们的代码:https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai。如果您有任何问题,请随时在 GitHub 上提出问题,我们将很乐意提供帮助。欢迎您的贡献!