一种高效的采用低精度 KV 缓存的解码分组查询注意力
引言
生成式 AI 以其像人类一样生成内容的能力席卷全球。许多此类生成式 AI 工具都由大型语言模型(LLM)驱动,例如 Meta 的 Llama 模型和 OpenAI 的 ChatGPT。LLM 面临的主要挑战之一是支持较长的“上下文长度”(也称为“序列长度”)。上下文长度是指模型用于理解输入上下文和生成响应的 token 数量。更长的上下文长度通常意味着响应具有更高的精度和质量。然而,长上下文长度对计算和内存要求很高。这主要是由于以下原因:
- 注意力层的计算复杂度随上下文长度线性增加(增长率取决于注意力算法)。因此,在使用长上下文长度时,注意力层可能成为瓶颈,尤其是在预填充阶段,该阶段的注意力计算受计算能力限制。
- KV 缓存的大小随上下文长度线性增长,从而对内存需求造成更大压力,并因此减缓了本已受内存限制的注意力解码速度。此外,由于内存容量有限,当 KV 缓存变大时,批量大小会减小,这通常会导致吞吐量下降。
与上面提到的其他问题相比,计算复杂度的增长更难解决。解决 KV 缓存大小增长问题的一种方法是使用低精度 KV 缓存。从我们的实验来看,在 Meta Llama 2 推理的解码阶段,分组式 INT4 量化在精度方面与 BF16 KV 缓存提供了可比的结果。然而,尽管在注意力解码层读取的数据量减少了 4 倍,但我们并未观察到任何延迟改善。这意味着 INT4 注意力在利用宝贵的 HBM 带宽方面比 BF16 注意力效率低 4 倍。
在本文中,我们将讨论我们对 INT4 GQA(分组查询注意力——我们在 LLM 推理阶段使用的注意力层)应用的 CUDA 优化,以将其性能提高多达 在 NVIDIA A100 GPU 上达到 1.8 倍 和 在 NVIDIA H100 GPU 上达到 1.9 倍。
- 优化的 CUDA INT4 GQA 在性能上优于 INT4 Flash-Decoding GQA(我们在上述实验中使用的性能最佳的 INT4 GQA),优于幅度为 在 A100 上为 1.4 倍至 1.7 倍,以及 在 H100 上为 1.09 倍至 1.3 倍。
- 优化的 CUDA INT4 GQA 性能优于 BF16 Flash-Decoding GQA,优于幅度为 在 A100 上为 1.5 倍至 1.7 倍,在 H100 上为 1.4 倍至 1.7 倍。
背景
用于 LLM 推理的 GQA
分组查询注意力 (GQA) 是多头注意力(MHA)的一种变体,其中每个 KV 缓存头在一组查询头之间共享。我们的 LLM 推理在预填充和解码阶段都采用 GQA 作为注意力层,以减少对 KV 缓存的容量需求。我们在推理中使用多个 GPU,其中 KV 缓存和查询头分布在不同的 GPU 上。每个 GPU 运行一个注意力层,包含一个 KV 头和一组 Q 头。因此,从单个 GPU 的角度来看,GQA 组件也可以描述为 MQA(多查询注意力)。
解码 GQA 的简化工作流程如图 1 所示。GQA 接收三个主要输入:输入查询(表示为 Q
)、K 缓存(表示为 K
)和 V 缓存(表示为 V
)。我们当前的 GQA 推理对 Q
、K
和 V
使用 BF16 类型。
Q
是一个形状为 (B
,1
,HQ
,D
) 的 4D BF16 张量K
是一个形状为 (B
,Tmax
,HKV
,D
) 的 4D BF16 张量V
是一个形状为 (B
,Tmax
,HKV
,D
) 的 4D BF16 张量
其中
B
是批量大小(输入提示的数量)HQ
是查询头的数量HKV
是 KV 头的数量(HQ
必须能被HKV
整除)Tmax
是最大上下文长度D
是头维度(固定为 128)
GQA 简单来说就是 bmm(softmax(bmm(Q, KT) / sqrt(D)), V)
。这产生一个输出张量(表示为 O
),它是一个形状与 Q
相同的 4D BF16 张量。注意,矩阵乘法使用 BF16 执行,但累加和 softmax
在 FP32 中进行。由于 KV 缓存是 BF16 类型,我们称之为“BF16 GQA”。
图 1 用于 LLM 推理的 BF16 GQA 的简化工作流程
INT4 GQA
为了进一步减小 KV 缓存的大小,我们探索了使用 INT4 代替 BF16 作为 KV 缓存的可能性。我们通过计算 INT4 GQA 的计算强度(CI)并将其与 BF16 GQA 的计算强度进行比较来估计潜在的性能提升,因为 CI 代表每字节的 FLOPS。我们计算 QKT
和 PV
的 CI(如公式 1 所示),因为它们将 KV 缓存作为操作数。注意,我们忽略 Q
的加载,因为它与 KV 缓存相比可以忽略不计。我们也忽略任何不在全局内存上的中间数据加载/存储。因此,CI 只考虑计算 FLOPS 和 KV 缓存加载。
公式 (1)
假设 HQ
= 8 且 HKV
= 1,BF16 KV 缓存的 CI 为 8,而 INT4 KV 缓存的 CI 为 32。这些 CI 表明 BF16 和 INT4 GQA 都受内存限制(A100 和 H100 的 BF16 张量核的峰值 CI 分别为 312 TF / 2 TB/s = 141 和 990 TF / 3.35 TB/s = 269;请注意,这些 TF 数字未考虑稀疏性)。此外,使用 INT4 KV 缓存,我们应该期望性能比 BF16 GQA 提升多达 4 倍。
为了在 GQA 中启用 INT4 KV 缓存支持,我们可以在将 KV 缓存传递给 BF16 GQA 算子之前,将其从 INT4 反量化为 BF16。然而,由于 KV 缓存通常很大,将其从全局内存复制或复制到全局内存可能会耗费大量开销。此外,解码 GQA 是一种受内存限制的操作(内存单元比计算单元利用得更重)。图 2 显示了 xFormers 中的 FMHA CUTLASS BF16 GQA kernel 的 NCU 性能分析结果,它是 GQA 最先进的实现之一。从图上看,显然内存是瓶颈。
图 2 xFormers 中的 FMHA CUTLASS BF16 kernel 的 NCU 性能分析结果
一种更有效的替代方案是将 INT4 反量化与 GQA 操作融合(如图 3 所示)。换句话说,让 GQA 直接读取 INT4 KV 缓存,并在 kernel 中执行 INT4 到 BF16 的转换。这一改变可以潜在地减少 KV 缓存所需的全局内存读取量,这可能会导致延迟降低。我们称之为“INT4 GQA”。
图 3 融合 INT4 GQA 的工作流程
我们在下表中列出了 GQA 最先进的实现及其特性(表 1)。
表 1 最先进的 GQA 实现
实现 | 表示 | BF16 GQA | 融合 INT4 GQA |
Flash-Decoding (Triton 实现) | FD | 是 | 是 |
Flash Attention (v2.3.3) | FA | 是 | 否 |
CUDA 基准 | CU | 是 | 是 |
所有实现(CU 除外)都支持 split-K 和非 split-K。CU 只有 split-K 实现。只有 FA 在后端具有启发式算法,用于确定是运行 split-K 还是非 split-K kernel。对于其他实现,用户必须明确选择要运行的版本。在本文中,我们重点关注长上下文长度(在我们的实验中,我们使用上下文长度为 8192),因此尽可能选择 split-K 版本。
作为基准,我们在 NVIDIA A100 和 H100 GPU 上测量了最先进 GQA 实现的性能。延迟(以微秒为单位的时间)和实现的带宽(GB/s)报告在表 2 中。注意,我们运行了一系列 split-K(从 2 到 128 个 split),并报告了每个实现的最佳性能。对于所有实验,我们使用上下文长度 8192。对于 INT4 GQA,我们使用了按行量化(即量化组数 = 1)。
表 2 基准 GQA 性能
在 A100 上
时间 (微秒) | BF16 GQA | INT4 GQA | ||||
批量大小 | FD | FA | CU | FD | FA | CU |
32 | 139 | 133 | 183 | 137 | - | 143 |
64 | 245 | 229 | 335 | 234 | - | 257 |
128 | 433 | 555 | 596 | 432 | - | 455 |
256 | 826 | 977 | 1127 | 815 | - | 866 |
512 | 1607 | 1670 | 2194 | 1581 | - | 1659 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||
批量大小 | FD | FA | CU | FD | FA | CU |
32 | 965 | 1012 | 736 | 262 | - | 250 |
64 | 1097 | 1175 | 802 | 305 | - | 278 |
128 | 1240 | 968 | 901 | 331 | - | 314 |
256 | 1301 | 1100 | 954 | 351 | - | 331 |
512 | 1338 | 1287 | 980 | 362 | - | 345 |
在 H100 上
时间 (微秒) | BF16 GQA | INT4 GQA | ||||
批量大小 | FD | FA | CU | FD | FA | CU |
32 | 91 | 90 | 114 | 70 | - | 96 |
64 | 148 | 146 | 200 | 113 | - | 162 |
128 | 271 | 298 | 361 | 205 | - | 294 |
256 | 515 | 499 | 658 | 389 | - | 558 |
512 | 1000 | 1011 | 1260 | 756 | - | 1066 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||
批量大小 | FD | FA | CU | FD | FA | CU |
32 | 1481 | 1496 | 1178 | 511 | - | 371 |
64 | 1815 | 1840 | 1345 | 631 | - | 443 |
128 | 1982 | 1802 | 1487 | 699 | - | 487 |
256 | 2087 | 2156 | 1634 | 736 | - | 513 |
512 | 2150 | 2127 | 1706 | 757 | - | 537 |
首先,我们来讨论 BF16 GQA 的性能:在所有实现中,CU 的性能排名最后。FD 和 FA 的性能相当。当批量大小小于或等于 64 时,FA 利用 split-K kernel,性能略优于 FD。然而,当批量大小大于 64 时,FD 性能更佳。
同样的趋势也适用于 INT4 GQA。然而,由于 FA 不支持 INT4 KV 缓存,我们没有测量其性能。在所有情况下,FD 都优于 CU。
比较 BF16 和 INT4 GQA 的 FD 延迟时,我们发现它们几乎相同。这表明 INT4 GQA 的效率非常低,INT4 GQA 可实现的带宽显著低于 BF16 GQA,这进一步证实了这一点。查看 CU 的性能时,同样的趋势也适用。
使用张量核的 CUDA INT4 GQA 实现
在本节中,我们将简要介绍我们的基准实现,即使用张量核的 CUDA INT4 GQA(CU)。每个线程块仅处理一个 KV 头和来自一个输入提示的一组查询头。因此,每个线程块执行 mm(softmax(mm(Q, KT) / sqrt(D)), V)
;请注意,这里执行的是 mm
而非 bmm
。此外,由于这是一个 split-K 实现,KV 缓存中的 token 分割到不同的线程块中。注意,每个线程块包含 4 个 warp(对于 NVIDIA A100 和 H100 GPU,每个 warp 包含 32 个线程)。每个线程块中的工作分配给不同的 warp。在每个 warp 中,我们使用 WMMA API 在张量核上计算矩阵乘法。图 4 演示了 CU 中的工作划分。
图 4 CU 工作划分
优化使用张量核的 INT4 GQA CUDA Kernel
在本文中,我们将讨论我们对使用张量核的 CUDA INT4 GQA 实现(CU)所应用的优化。理想目标是基于前一节中的 CI 分析,将 INT4 GQA 的性能提高 4 倍。注意,当上下文长度很长时,查询大小与 KV 缓存大小相比可以忽略不计。
在我们的分析中,我们使用 NVIDIA Nsight Compute (NCU) 作为主要的性能分析器。我们通用的瓶颈消除方法是最小化停滞周期。我们对 INT4 GQA 应用了 10 项优化,其中三项是 NVIDIA A100/H100 GPU 特有的。这些优化是众所周知的 CUDA 优化技术,可以推广到许多应用中。
值得注意的是,我们选择优化 CUDA 实现而非 Flash-Decoding 实现(FD,基于 Triton)的原因在于,使用 CUDA,我们可以更好地控制底层指令的生成方式。我们应用的许多优化技术,例如直接操作张量核片段(优化 7-9),无法通过 Triton 实现,因为它不向开发者暴露底层细节。然而,这些优化可以集成到基于编译器的解决方案中,使这些优化可用于更广泛的算子,这确实是我们未来计划的一部分。
优化 1:展开 K
加载
问题分析
NCU 性能分析结果显示,在 K
加载期间,只有 2 个全局加载,随后在 dequantize_permuted_int4
处出现内存停滞。内存停滞是长的记分板停滞,这表明在等待全局内存访问。这表明 kernel 没有发出足够的内存加载:
以隐藏全局加载延迟。kernel 发出数据加载指令后,立即等待消费数据,从而暴露了全局加载延迟。停滞情况如图 5 所示。
图 5 展开前的 K 加载(箭头指向的数字表示由全局内存等待引起的停滞周期)
解决方案
在基准实现中,我们使用 uint32_t
在一次加载中加载 8 个 INT4 K
值,并且在每次迭代中执行 2 次 uint32_t
加载,即 16 个 INT4 K 值。为了更好地隐藏全局加载延迟,我们在消费 dequantize_permuted_int4
中的 K
值之前,发出 8 次 uint32_t
加载,而非 2 次。这使得编译器能够展开加载并重新排序指令,从而更好地隐藏全局加载延迟。图 6 显示了展开后 K
加载的 NCU 性能分析结果。比较图 5 和图 6,我们通过展开 K
加载,有效减少了停滞周期。
图 6 展开后的 K 加载(箭头指向的数字表示由全局内存等待引起的停滞周期)
结果
表 3 优化 1 对 INT4 GQA 的性能提升(按行量化)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 1 | 基准 | 优化 1 | |||||
32 | 137 | 143 | 134 | 262 | 250 | 267 | 1.02 | 1.07 |
64 | 234 | 257 | 237 | 305 | 278 | 302 | 0.99 | 1.09 |
128 | 432 | 455 | 422 | 331 | 314 | 339 | 1.02 | 1.08 |
256 | 815 | 866 | 806 | 351 | 331 | 355 | 1.01 | 1.07 |
512 | 1581 | 1659 | 1550 | 362 | 345 | 369 | 1.02 | 1.07 |
优化 2:改进 P
类型转换(FP32->BF16)
问题分析
由于 softmax(bmm(Q, KT) / sqrt(D))
的结果是 FP32(在图 3 中表示为 P
),kernel 必须在将其馈送到下一个 bmm
计算之前将 P
从 FP32 转换为 BF16。kernel 通过将 FP32 数据从共享内存的一个位置复制到共享内存的另一个位置来执行 P
的 FP32 到 BF16 转换。这导致在访问共享内存期间出现停滞(如图 7 所示),这可能由 (1) 共享内存间接寻址;以及 (2) 共享内存 bank 冲突引起,因为每个线程访问一个 16 位元素(因此,两个线程可以同时访问同一个 memory bank)。
图 7 优化 2 之前的 P
类型转换(箭头指向的数字表示由共享内存等待引起的停滞周期)
解决方案
我们使用线程块中的所有线程进行就地类型转换。每个线程操作两个连续的元素,以避免在存储 BF16 时出现共享内存 bank 冲突。所有线程同时处理同一个头(h
),以保证转换的正确性。就地转换步骤如下:
- 每个线程将同一个头的 2 个 FP32 token 元素从共享内存加载到寄存器中
- 调用
__syncthreads()
以确保每个线程完成数据读取 - 每个线程将其数据转换为 2 个 BF16 token 元素,然后将结果存储回同一个共享内存
我们对实现应用的一些优化:
- 使用向量类型(特别是
nv_bfloat2
) - 展开数据加载/存储,即在调用
__syncthreads()
之前执行多次加载,并在调用__syncthreads()
之后执行多次存储
应用此优化后,在 P
类型转换期间未观察到长停滞,如图 8 所示。
图 8 优化 2 之后的 P
类型转换(箭头指向的数字表示由共享内存等待引起的停滞周期)
潜在问题
由于我们通过使用寄存器作为中间存储来展开数据加载/存储,因此每个线程的寄存器数量增加,导致占用率降低。
结果
表 4 优化 2 对 INT4 GQA 的性能提升(按行量化)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 2 | 基准 | 优化 2 | |||||
32 | 137 | 143 | 126 | 262 | 250 | 285 | 1.09 | 1.14 |
64 | 234 | 257 | 221 | 305 | 278 | 324 | 1.06 | 1.16 |
128 | 432 | 455 | 395 | 331 | 314 | 362 | 1.09 | 1.15 |
256 | 815 | 866 | 749 | 351 | 331 | 382 | 1.09 | 1.16 |
512 | 1581 | 1659 | 1435 | 362 | 345 | 399 | 1.10 | 1.16 |
优化 3:移除 max QKT
计算中的局部内存使用
问题分析
在 softmax 计算期间,kernel 必须计算每个头的 max QKT
。它使用一个临时的“线程局部”存储来存储每个线程的 max QKT
结果(每个头一个 float 值)。根据编译器,线程局部存储可以分配在寄存器(片上)或局部内存(片外 == 全局内存)上。不幸的是,在基准中,线程局部存储位于局部内存中,这比寄存器慢得多(如图 9 所示)。我们怀疑这是因为编译器无法在编译时确定线程局部存储的索引(因为 kernel 中的头数量(H
)是运行时变量)。像访问寄存器一样访问局部内存可能会损害 kernel 的性能。
图 9 在 max QKT
计算期间的局部内存访问
解决方案
我们意识到,我们不需要每个线程使用 H
(头数量)个 float 作为临时存储,因为每个线程可以只为一个头计算 max QKT
,而非为所有头计算。因此,我们每个线程只需一个 float,这可以很容易地存储在寄存器中。为了在 warp 之间累加最大值结果,我们使用共享内存。此优化消除了在 max QKT
计算期间对局部内存的使用。
结果
表 5 优化 3 对 INT4 GQA 的性能提升(按行量化)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 3 | 基准 | 优化 3 | |||||
32 | 137 | 143 | 119 | 262 | 250 | 300 | 1.14 | 1.20 |
64 | 234 | 257 | 206 | 305 | 278 | 348 | 1.14 | 1.25 |
128 | 432 | 455 | 368 | 331 | 314 | 389 | 1.17 | 1.24 |
256 | 815 | 866 | 696 | 351 | 331 | 411 | 1.17 | 1.24 |
512 | 1581 | 1659 | 1338 | 362 | 345 | 428 | 1.18 | 1.24 |
优化 4:移除行求和中的局部内存使用
问题分析
类似于 优化 3,在 softmax
计算的行求和期间也观察到局部内存使用问题。由于局部内存是片外内存,像访问寄存器一样访问它可能会损害 kernel 的性能。
解决方案:
我们对行求和计算应用了与 max QKT
计算相同的解决方案。即让每个线程只计算一个头的行求和,这每个线程只需一个 float。这消除了对局部内存的需求。
结果
表 6 优化 4 对 INT4 GQA 的性能提升(按行量化)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 4 | 基准 | 优化 4 | |||||
32 | 137 | 143 | 118 | 262 | 250 | 302 | 1.15 | 1.21 |
64 | 234 | 257 | 204 | 305 | 278 | 351 | 1.15 | 1.26 |
128 | 432 | 455 | 364 | 331 | 314 | 393 | 1.19 | 1.25 |
256 | 815 | 866 | 688 | 351 | 331 | 416 | 1.18 | 1.26 |
512 | 1581 | 1659 | 1328 | 362 | 345 | 431 | 1.19 | 1.25 |
优化 5:为 V
加载添加预取
问题分析
加载 V
时也观察到与加载 K
相同的问题。也就是说,kernel 发出数据加载指令后,立即等待消费数据,从而暴露了全局加载延迟。然而,当使用上述展开技术时,编译器会将临时缓冲区分配到局部内存而不是寄存器中,导致性能大幅下降。
解决方案
我们为 V
加载采用数据预取技术。在消费当前迭代值后,我们立即加载下一迭代的 V
值。这使得数据加载与 PV
计算重叠,从而提高了 kernel 性能。
结果
表 7 优化 5 对 INT4 GQA 的性能提升(按行量化)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 5 | 基准 | 优化 5 | |||||
32 | 137 | 143 | 109 | 262 | 250 | 327 | 1.25 | 1.31 |
64 | 234 | 257 | 194 | 305 | 278 | 370 | 1.21 | 1.33 |
128 | 432 | 455 | 345 | 331 | 314 | 414 | 1.25 | 1.32 |
256 | 815 | 866 | 649 | 351 | 331 | 441 | 1.26 | 1.33 |
512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |
优化 6:添加分组式 INT4(组数 = 4)并使用向量加载
问题分析
在此优化之前,CU 仅支持按行 INT4 量化。也就是说,每行中的每列共享相同的缩放因子。每行的缩放因子存储在每行的前 4 字节中,如图 10 所示。在 kernel 中,每个线程一次只加载一行。由于每行包含 68 字节(4 字节用于缩放因子,64 字节用于数据),因此无法保证每行与任何向量类型的大小对齐。因此,无法使用向量加载来加载 KV 缓存。
图 10 按行量化 INT4 KV 缓存每行的布局
解决方案
我们实现了对分组式 INT4 量化(组数 = 4)的支持。在这种情况下,KV 缓存张量每行中的列被分成 4 个相等的组。同一组内的列共享相同的缩放因子进行量化/反量化。INT4 KV 缓存的数据布局如图 11 所示。所有组的缩放因子被序列化并存储在每行的开头。INT4 数据也被序列化并布局在缩放因子旁边。
由于每行中的字节数现在变为 80 字节,我们可以使用向量类型(即在我们的情况下为 uint2
)来加载数据。(我们不使用 uint4
,因为每个线程一次只加载 16 个 INT4 值,这是由于张量核片段大小的限制。)向量加载通常优于标量加载,因为它不会导致额外的字节加载。
图 11 分组式 INT4 量化(组数 = 4)INT4 KV 缓存每行的布局
结果
表 8 优化 6 对 INT4 GQA 的性能提升(按行量化)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 6 | 基准 | 优化 6 | |||||
32 | 137 | 143 | 111 | 262 | 250 | 322 | 1.23 | 1.29 |
64 | 234 | 257 | 192 | 305 | 278 | 372 | 1.22 | 1.34 |
128 | 432 | 455 | 346 | 331 | 314 | 414 | 1.25 | 1.32 |
256 | 815 | 866 | 642 | 351 | 331 | 446 | 1.27 | 1.35 |
512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |
表 9 优化 6 对 INT4 GQA 的性能提升(分组式量化,组数 = 4)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | ||
FD | CUDA_WMMA | FD | CUDA_WMMA | 相对于 FD | |
优化 6 | 优化 6 | ||||
32 | 129 | 116 | 325 | 364 | 1.31 |
64 | 219 | 195 | 385 | 431 | 1.36 |
128 | 392 | 347 | 429 | 484 | 1.39 |
256 | 719 | 638 | 468 | 527 | 1.41 |
512 | 1375 | 1225 | 489 | 550 | 1.43 |
优化 7:直接从 WMMA 片段计算 max QKT
(A100/H100 特有)
问题分析
我们观察到在 max QKT
计算期间,由于访问共享内存而出现大量停滞(表现为大的短记分板停滞),如图 12 所示。
图 12 在 max QKT
计算期间因共享内存访问引起的停滞(箭头指向的数字表示由共享内存等待引起的停滞周期)
解决方案
我们通过直接从 WMMA 片段(即张量核片段)计算 max QKT
,从而绕过共享内存。WMMA 片段的布局取决于 GPU 架构。在此优化中,我们仅为 NVIDIA A100/H100 GPU 启用了此优化。其他 GPU 在计算 max QKT
时仍将使用共享内存。通过绕过共享内存,我们有效消除了由共享内存访问引起的停滞。用于存储 QKT
结果的 C
片段的张量核布局如图 13 所示。
图 13 A100/H100 上 C
片段(QKT
存储)的张量核布局
表 10 优化 7 对 INT4 GQA 的性能提升(按行量化)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 7 | 基准 | 优化 7 | |||||
32 | 137 | 143 | 107 | 262 | 250 | 333 | 1.27 | 1.33 |
64 | 234 | 257 | 183 | 305 | 278 | 391 | 1.28 | 1.40 |
128 | 432 | 455 | 333 | 331 | 314 | 430 | 1.30 | 1.37 |
256 | 815 | 866 | 620 | 351 | 331 | 461 | 1.31 | 1.40 |
512 | 1581 | 1659 | 1206 | 362 | 345 | 475 | 1.31 | 1.38 |
表 11 优化 7 对 INT4 GQA 的性能提升(分组式量化,组数 = 4)
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 相对于 FD | 相对于 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 7 | 优化 6 | 优化 7 | |||||
32 | 129 | 116 | 111 | 325 | 364 | 380 | 1.17 | 1.04 |
64 | 219 | 195 | 187 | 385 | 431 | 449 | 1.17 | 1.04 |
128 | 392 | 347 | 333 | 429 | 484 | 506 | 1.18 | 1.04 |
256 | 719 | 638 | 615 | 468 | 527 | 547 | 1.17 | 1.04 |
512 | 1375 | 1225 | 1184 | 489 | 550 | 569 | 1.16 | 1.03 |
优化 8:将 FP32->BF16 结果直接写入 P
片段(A100/H100 特有)
问题分析
在为 P
片段进行 FP32-BF16 转换期间,kernel 从共享内存加载 FP32 数据,进行转换,然后将 BF16 数据存储回共享内存。此外,此转换需要多次线程块同步(__syncthreads()
)。
解决方案
由于 kernel 的数据划分设计,每个 warp 只对 P
片段执行一次遍历。因此,我们无需将转换结果写回共享内存以供将来使用。为了避免将 BF16 数据写入共享内存以及线程块同步,我们让每个 warp 从共享内存加载 P
WMMA 片段的 FP32 数据,进行转换,然后将 BF16 数据直接写入 P
片段。
注意,此优化仅应用于 NVIDIA A100 和 H100 GPU,因为 WMMA 片段布局依赖于架构。对于非 A100/H100 GPU,kernel 将回退到原始路径。
P
片段的张量核布局如图 14 所示。注意,此布局仅适用于 NVIDIA A100/H100 GPU。
图 14 A100/H100 上 P
片段的张量核布局
表 12 INT4 GQA (行式量化) 的优化 8 性能
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 8 (Opt 8) | 基准 | 优化 8 (Opt 8) | |||||
32 | 137 | 143 | 101 | 262 | 250 | 353 | 1.35 | 1.41 |
64 | 234 | 257 | 174 | 305 | 278 | 410 | 1.34 | 1.47 |
128 | 432 | 455 | 317 | 331 | 314 | 451 | 1.36 | 1.43 |
256 | 815 | 866 | 590 | 351 | 331 | 485 | 1.38 | 1.47 |
512 | 1581 | 1659 | 1143 | 362 | 345 | 501 | 1.38 | 1.45 |
表 13 INT4 GQA (组式量化,组数 = 4) 的优化 8 性能
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 相对于 FD | 相对于 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 8 (Opt 8) | 优化 6 | 优化 8 (Opt 8) | |||||
32 | 129 | 116 | 106 | 325 | 364 | 396 | 1.22 | 1.09 |
64 | 219 | 195 | 180 | 385 | 431 | 467 | 1.21 | 1.08 |
128 | 392 | 347 | 319 | 429 | 484 | 528 | 1.23 | 1.09 |
256 | 719 | 638 | 596 | 468 | 527 | 565 | 1.21 | 1.07 |
512 | 1375 | 1225 | 1138 | 489 | 550 | 591 | 1.21 | 1.08 |
优化 9:混洗 (Swizzle) P 共享内存布局 (A100/H100 特定)
问题分析
我们在加载 P
时观察到大量的共享内存 bank 冲突。bank 冲突的数量取决于内存访问步长。例如,对于 split-Ks = 32 和最大序列长度 = 8192 的情况,我们观察到在并行访问时,32 个 bank 中只有 4 个被访问(内存访问步长 = 256)。从图 14 中可以看出,当所有线程访问元素 0 时,具有相同 threadIdx.x % 4
的线程访问同一个 bank。
图 15 混洗前的共享内存中的 P 片段
解决方案
我们通过混洗 (swizzle) 共享内存中 P
加载/存储的布局,以避免 bank 冲突。换句话说,我们使用混洗后的布局存储 QKT
的结果(C
片段)并加载它们(P
片段)。此外,我们不再使用依赖于每个线程块的 token 数量的原始内存访问步长,而是使用片段的列大小作为步长,该步长是恒定的。因此,P
片段的加载和存储始终是连续的。
C 和 P 片段的新布局如图 16 所示。通过新布局,可以保证 16 个 bank 被并行访问,如图 17 所示。
图 16 C 和 P 片段的混洗布局
图 17 混洗后的共享内存中的 P 片段
表 14 INT4 GQA (行式量化) 的优化 9 性能
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 9 (Opt 9) | 基准 | 优化 9 (Opt 9) | |||||
32 | 137 | 143 | 98 | 262 | 250 | 365 | 1.39 | 1.46 |
64 | 234 | 257 | 167 | 305 | 278 | 429 | 1.41 | 1.54 |
128 | 432 | 455 | 299 | 331 | 314 | 479 | 1.45 | 1.52 |
256 | 815 | 866 | 549 | 351 | 331 | 521 | 1.48 | 1.58 |
512 | 1581 | 1659 | 1060 | 362 | 345 | 540 | 1.49 | 1.56 |
表 15 INT4 GQA (组式量化,组数 = 4) 的优化 9 性能
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 相对于 FD | 相对于 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 9 (Opt 9) | 优化 6 | 优化 9 (Opt 9) | |||||
32 | 129 | 116 | 105 | 325 | 364 | 400 | 1.23 | 1.10 |
64 | 219 | 195 | 174 | 385 | 431 | 484 | 1.26 | 1.12 |
128 | 392 | 347 | 302 | 429 | 484 | 558 | 1.30 | 1.15 |
256 | 719 | 638 | 560 | 468 | 527 | 601 | 1.28 | 1.14 |
512 | 1375 | 1225 | 1065 | 489 | 550 | 632 | 1.29 | 1.15 |
优化 10:为 INT4 反量化填充共享内存
问题分析
一旦内核从全局内存读取 INT4 的 K
或 V
缓存,它会执行反量化并将结果 (BF16) 存储在共享内存中。然后,BF16 数据从共享内存加载到 WMMA 片段(通过 WMMA 接口)。我们观察到 K
和 V
访问都存在大量 bank 冲突。例如,对于 K
存储,32 个 bank 中只有 4 个被并行访问。对于 K
加载,16 个 bank 被并行访问。对于 V
存储和加载也发生同样的情况。参见解决方案部分的图。
解决方案
我们对共享内存进行填充以减少 bank 冲突。具体来说,我们将每一行填充 2。也就是说,K
的行步长变为 F_K
+ 2,V 的行步长变为 F_N
+ 2(F_K
和 F_N
分别是 K
和 V
WMMA 片段的固定宽度)。通过这种优化,我们能够将 bank 冲突减少 1.8 倍,如图 18 所示。
图 18 优化 10 前后的 Bank 冲突
优化 10 后,对于 K
存储,32 个 bank 被并行访问(如图 19 所示),而对于 K
加载,29 个 bank 被并行访问(如图 20 所示)。
图 19 K 片段存储共享内存布局,带填充和不带填充
图 20 K 片段加载共享内存布局,带填充和不带填充
表 16 INT4 GQA (行式量化) 的优化 10 性能
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CU | FD | CU | 相对于 FD | 相对于 CU 基准 | |||
基准 | 优化 10 (Opt 10) | 基准 | 优化 10 (Opt 10) | |||||
32 | 137 | 143 | 94 | 262 | 250 | 380 | 1.45 | 1.52 |
64 | 234 | 257 | 151 | 305 | 278 | 475 | 1.55 | 1.71 |
128 | 432 | 455 | 266 | 331 | 314 | 538 | 1.63 | 1.71 |
256 | 815 | 866 | 489 | 351 | 331 | 586 | 1.67 | 1.77 |
512 | 1581 | 1659 | 930 | 362 | 345 | 616 | 1.70 | 1.79 |
表 17 INT4 GQA (组式量化,组数 = 4) 的优化 10 性能
批量大小 | 时间 (微秒) | 带宽 (GB/s) | 加速比 | |||||
FD | CUDA_WMMA | FD | CUDA_WMMA | 相对于 FD | 相对于 CUDA_WMMA 优化 6 | |||
优化 6 | 优化 10 (Opt 10) | 优化 6 | 优化 10 (Opt 10) | |||||
32 | 129 | 116 | 99 | 325 | 364 | 425 | 1.31 | 1.17 |
64 | 219 | 195 | 161 | 385 | 431 | 523 | 1.36 | 1.21 |
128 | 392 | 347 | 282 | 429 | 484 | 598 | 1.39 | 1.23 |
256 | 719 | 638 | 509 | 468 | 527 | 662 | 1.41 | 1.25 |
512 | 1375 | 1225 | 965 | 489 | 550 | 698 | 1.43 | 1.27 |
性能评估
微基准测试结果
我们还使用我们优化的内核评估了 BF16 GQA 的性能(如表 19 所示)。对于 BF16,CU 的性能通常仍低于 FD 和 FA。这是预期结果,因为我们的优化主要针对 INT4。
虽然 INT4 GQA 仍然不如 BF16 GQA 高效(参见达到的带宽),但重要的是要注意,将 FD BF16 GQA 性能与 CU INT4 GQA 性能进行比较时,我们可以看到 INT4 的延迟小于 BF16。
表 19 CU 优化后 BF16 GQA 和 INT GQA 的性能
在 A100 上
时间 (微秒) | BF16 GQA | INT4 GQA | ||||||
批量大小 | FD | FA | 优化前 CU | 优化后 CU | FD | FA | 优化前 CU | 优化后 CU |
32 | 139 | 133 | 183 | 163 | 137 | - | 143 | 94 |
64 | 245 | 229 | 335 | 276 | 234 | - | 257 | 151 |
128 | 433 | 555 | 596 | 517 | 432 | - | 455 | 266 |
256 | 826 | 977 | 1127 | 999 | 815 | - | 866 | 489 |
512 | 1607 | 1670 | 2194 | 1879 | 1581 | - | 1659 | 930 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||||
批量大小 | FD | FA | 优化前 CU | 优化后 CU | FD | FA | 优化前 CU | 优化后 CU |
32 | 965 | 1012 | 736 | 824 | 262 | - | 250 | 380 |
64 | 1097 | 1175 | 802 | 972 | 305 | - | 278 | 475 |
128 | 1240 | 968 | 901 | 1039 | 331 | - | 314 | 538 |
256 | 1301 | 1100 | 954 | 1075 | 351 | - | 331 | 586 |
512 | 1338 | 1287 | 980 | 1144 | 362 | - | 345 | 616 |
在 H100 上
时间 (微秒) | BF16 GQA | INT4 GQA | ||||||
批量大小 | FD | FA | 优化前 CU | 优化后 CU | FD | FA | 优化前 CU | 优化后 CU |
32 | 91 | 90 | 114 | 100 | 70 | - | 96 | 64 |
64 | 148 | 146 | 200 | 183 | 113 | - | 162 | 101 |
128 | 271 | 298 | 361 | 308 | 205 | - | 294 | 170 |
256 | 515 | 499 | 658 | 556 | 389 | - | 558 | 306 |
512 | 1000 | 1011 | 1260 | 1066 | 756 | - | 1066 | 575 |
有效带宽 (GB/s) | BF16 GQA | INT4 GQA | ||||||
批量大小 | FD | FA | 优化前 CU | 优化后 CU | FD | FA | 优化前 CU | 优化后 CU |
32 | 1481 | 1496 | 1178 | 1341 | 511 | - | 371 | 560 |
64 | 1815 | 1840 | 1345 | 1470 | 631 | - | 443 | 710 |
128 | 1982 | 1802 | 1487 | 1743 | 699 | - | 487 | 844 |
256 | 2087 | 2156 | 1634 | 1934 | 736 | - | 513 | 935 |
512 | 2150 | 2127 | 1706 | 2015 | 757 | - | 537 | 996 |
端到端 (E2E) 结果
我们在 8 个 H100 GPU 上评估了 Llama 2 70B 模型中我们优化的 INT4 GQA 内核的端到端性能。我们运行了完整的模型,但仅报告了解码延迟。我们使用 FP8 FFN (前馈网络) 来强调解码阶段的注意力性能。我们将批量大小从 1 变化到 256,上下文长度从 2,048 (2K) 变化到 16,384 (16K)。端到端性能结果如下图所示。
图 21 Meta Llama 2 解码延迟 (毫秒) 比较 (BF16 GQA 在大批量配置下内存不足)
代码
如果您有兴趣,请查看我们的代码 此处。如果您有任何问题,请随时在 GitHub 上开启 issue,我们将很乐意提供帮助。欢迎您的贡献!