作者:Sarunya Pumma, Jongsoo Park, Jianyu Huang, Amy Yang, Jaewon Lee, Daniel Haziza, Grigory Sizov, Jeremy Reizenstein, Jeff Johnson, Ying Zhang

一种高效的采用低精度 KV 缓存的解码分组查询注意力

引言

生成式 AI 以其像人类一样生成内容的能力席卷全球。许多此类生成式 AI 工具都由大型语言模型(LLM)驱动,例如 Meta 的 Llama 模型和 OpenAI 的 ChatGPT。LLM 面临的主要挑战之一是支持较长的“上下文长度”(也称为“序列长度”)。上下文长度是指模型用于理解输入上下文和生成响应的 token 数量。更长的上下文长度通常意味着响应具有更高的精度和质量。然而,长上下文长度对计算和内存要求很高。这主要是由于以下原因:

  • 注意力层的计算复杂度随上下文长度线性增加(增长率取决于注意力算法)。因此,在使用长上下文长度时,注意力层可能成为瓶颈,尤其是在预填充阶段,该阶段的注意力计算受计算能力限制。
  • KV 缓存的大小随上下文长度线性增长,从而对内存需求造成更大压力,并因此减缓了本已受内存限制的注意力解码速度。此外,由于内存容量有限,当 KV 缓存变大时,批量大小会减小,这通常会导致吞吐量下降。

与上面提到的其他问题相比,计算复杂度的增长更难解决。解决 KV 缓存大小增长问题的一种方法是使用低精度 KV 缓存。从我们的实验来看,在 Meta Llama 2 推理的解码阶段,分组式 INT4 量化在精度方面与 BF16 KV 缓存提供了可比的结果。然而,尽管在注意力解码层读取的数据量减少了 4 倍,但我们并未观察到任何延迟改善。这意味着 INT4 注意力在利用宝贵的 HBM 带宽方面比 BF16 注意力效率低 4 倍。

在本文中,我们将讨论我们对 INT4 GQA(分组查询注意力——我们在 LLM 推理阶段使用的注意力层)应用的 CUDA 优化,以将其性能提高多达 在 NVIDIA A100 GPU 上达到 1.8 倍在 NVIDIA H100 GPU 上达到 1.9 倍

  • 优化的 CUDA INT4 GQA 在性能上优于 INT4 Flash-Decoding GQA(我们在上述实验中使用的性能最佳的 INT4 GQA),优于幅度为 在 A100 上为 1.4 倍至 1.7 倍,以及 在 H100 上为 1.09 倍至 1.3 倍。
  • 优化的 CUDA INT4 GQA 性能优于 BF16 Flash-Decoding GQA,优于幅度为 在 A100 上为 1.5 倍至 1.7 倍,在 H100 上为 1.4 倍至 1.7 倍。

背景

用于 LLM 推理的 GQA

分组查询注意力 (GQA) 是多头注意力(MHA)的一种变体,其中每个 KV 缓存头在一组查询头之间共享。我们的 LLM 推理在预填充和解码阶段都采用 GQA 作为注意力层,以减少对 KV 缓存的容量需求。我们在推理中使用多个 GPU,其中 KV 缓存和查询头分布在不同的 GPU 上。每个 GPU 运行一个注意力层,包含一个 KV 头和一组 Q 头。因此,从单个 GPU 的角度来看,GQA 组件也可以描述为 MQA(多查询注意力)

解码 GQA 的简化工作流程如图 1 所示。GQA 接收三个主要输入:输入查询(表示为 Q)、K 缓存(表示为 K)和 V 缓存(表示为 V)。我们当前的 GQA 推理对 QKV 使用 BF16 类型。

  • Q 是一个形状为 (B, 1, HQ, D) 的 4D BF16 张量
  • K 是一个形状为 (B, Tmax, HKV, D) 的 4D BF16 张量
  • V 是一个形状为 (B, Tmax, HKV, D) 的 4D BF16 张量

其中

  • B 是批量大小(输入提示的数量)
  • HQ 是查询头的数量
  • HKV 是 KV 头的数量(HQ 必须能被 HKV 整除)
  • Tmax 是最大上下文长度
  • D 是头维度(固定为 128)

GQA 简单来说就是 bmm(softmax(bmm(Q, KT) / sqrt(D)), V)。这产生一个输出张量(表示为 O),它是一个形状与 Q 相同的 4D BF16 张量。注意,矩阵乘法使用 BF16 执行,但累加和 softmax 在 FP32 中进行。由于 KV 缓存是 BF16 类型,我们称之为“BF16 GQA”。

Figure 1: The simplified workflow of BF16 GQA for LLM inference

图 1 用于 LLM 推理的 BF16 GQA 的简化工作流程

INT4 GQA

为了进一步减小 KV 缓存的大小,我们探索了使用 INT4 代替 BF16 作为 KV 缓存的可能性。我们通过计算 INT4 GQA 的计算强度(CI)并将其与 BF16 GQA 的计算强度进行比较来估计潜在的性能提升,因为 CI 代表每字节的 FLOPS。我们计算 QKTPV 的 CI(如公式 1 所示),因为它们将 KV 缓存作为操作数。注意,我们忽略 Q 的加载,因为它与 KV 缓存相比可以忽略不计。我们也忽略任何不在全局内存上的中间数据加载/存储。因此,CI 只考虑计算 FLOPS 和 KV 缓存加载。

Equation 1

公式 (1)

假设 HQ = 8 且 HKV = 1,BF16 KV 缓存的 CI 为 8,而 INT4 KV 缓存的 CI 为 32。这些 CI 表明 BF16 和 INT4 GQA 都受内存限制(A100 和 H100 的 BF16 张量核的峰值 CI 分别为 312 TF / 2 TB/s = 141990 TF / 3.35 TB/s = 269;请注意,这些 TF 数字未考虑稀疏性)。此外,使用 INT4 KV 缓存,我们应该期望性能比 BF16 GQA 提升多达 4 倍。

为了在 GQA 中启用 INT4 KV 缓存支持,我们可以在将 KV 缓存传递给 BF16 GQA 算子之前,将其从 INT4 反量化为 BF16。然而,由于 KV 缓存通常很大,将其从全局内存复制或复制到全局内存可能会耗费大量开销。此外,解码 GQA 是一种受内存限制的操作(内存单元比计算单元利用得更重)。图 2 显示了 xFormers 中的 FMHA CUTLASS BF16 GQA kernel 的 NCU 性能分析结果,它是 GQA 最先进的实现之一。从图上看,显然内存是瓶颈。

Figure 2: The NCU profile of the FMHA CUTLASS BF16 kernel in xFormers

图 2 xFormers 中的 FMHA CUTLASS BF16 kernel 的 NCU 性能分析结果

一种更有效的替代方案是将 INT4 反量化与 GQA 操作融合(如图 3 所示)。换句话说,让 GQA 直接读取 INT4 KV 缓存,并在 kernel 中执行 INT4 到 BF16 的转换。这一改变可以潜在地减少 KV 缓存所需的全局内存读取量,这可能会导致延迟降低。我们称之为“INT4 GQA”。

Figure 3: The workflow of fused INT4 GQA

图 3 融合 INT4 GQA 的工作流程

我们在下表中列出了 GQA 最先进的实现及其特性(表 1)。

表 1 最先进的 GQA 实现

实现 表示 BF16 GQA 融合 INT4 GQA
Flash-Decoding (Triton 实现) FD
Flash Attention (v2.3.3) FA
CUDA 基准 CU

所有实现(CU 除外)都支持 split-K 和非 split-K。CU 只有 split-K 实现。只有 FA 在后端具有启发式算法,用于确定是运行 split-K 还是非 split-K kernel。对于其他实现,用户必须明确选择要运行的版本。在本文中,我们重点关注长上下文长度(在我们的实验中,我们使用上下文长度为 8192),因此尽可能选择 split-K 版本。

作为基准,我们在 NVIDIA A100 和 H100 GPU 上测量了最先进 GQA 实现的性能。延迟(以微秒为单位的时间)和实现的带宽(GB/s)报告在表 2 中。注意,我们运行了一系列 split-K(从 2 到 128 个 split),并报告了每个实现的最佳性能。对于所有实验,我们使用上下文长度 8192。对于 INT4 GQA,我们使用了按行量化(即量化组数 = 1)。

表 2 基准 GQA 性能

在 A100 上

时间 (微秒) BF16 GQA INT4 GQA
批量大小 FD FA CU FD FA CU
32 139 133 183 137 - 143
64 245 229 335 234 - 257
128 433 555 596 432 - 455
256 826 977 1127 815 - 866
512 1607 1670 2194 1581 - 1659
有效带宽 (GB/s) BF16 GQA INT4 GQA
批量大小 FD FA CU FD FA CU
32 965 1012 736 262 - 250
64 1097 1175 802 305 - 278
128 1240 968 901 331 - 314
256 1301 1100 954 351 - 331
512 1338 1287 980 362 - 345

在 H100 上

时间 (微秒) BF16 GQA INT4 GQA
批量大小 FD FA CU FD FA CU
32 91 90 114 70 - 96
64 148 146 200 113 - 162
128 271 298 361 205 - 294
256 515 499 658 389 - 558
512 1000 1011 1260 756 - 1066
有效带宽 (GB/s) BF16 GQA INT4 GQA
批量大小 FD FA CU FD FA CU
32 1481 1496 1178 511 - 371
64 1815 1840 1345 631 - 443
128 1982 1802 1487 699 - 487
256 2087 2156 1634 736 - 513
512 2150 2127 1706 757 - 537

首先,我们来讨论 BF16 GQA 的性能:在所有实现中,CU 的性能排名最后。FD 和 FA 的性能相当。当批量大小小于或等于 64 时,FA 利用 split-K kernel,性能略优于 FD。然而,当批量大小大于 64 时,FD 性能更佳。

同样的趋势也适用于 INT4 GQA。然而,由于 FA 不支持 INT4 KV 缓存,我们没有测量其性能。在所有情况下,FD 都优于 CU。

比较 BF16 和 INT4 GQA 的 FD 延迟时,我们发现它们几乎相同。这表明 INT4 GQA 的效率非常低,INT4 GQA 可实现的带宽显著低于 BF16 GQA,这进一步证实了这一点。查看 CU 的性能时,同样的趋势也适用。

使用张量核的 CUDA INT4 GQA 实现

在本节中,我们将简要介绍我们的基准实现,即使用张量核的 CUDA INT4 GQA(CU)。每个线程块仅处理一个 KV 头和来自一个输入提示的一组查询头。因此,每个线程块执行 mm(softmax(mm(Q, KT) / sqrt(D)), V);请注意,这里执行的是 mm 而非 bmm。此外,由于这是一个 split-K 实现,KV 缓存中的 token 分割到不同的线程块中。注意,每个线程块包含 4 个 warp(对于 NVIDIA A100 和 H100 GPU,每个 warp 包含 32 个线程)。每个线程块中的工作分配给不同的 warp。在每个 warp 中,我们使用 WMMA API 在张量核上计算矩阵乘法。图 4 演示了 CU 中的工作划分。

Figure 4: CU work partitioning

图 4 CU 工作划分

优化使用张量核的 INT4 GQA CUDA Kernel

在本文中,我们将讨论我们对使用张量核的 CUDA INT4 GQA 实现(CU)所应用的优化。理想目标是基于前一节中的 CI 分析,将 INT4 GQA 的性能提高 4 倍。注意,当上下文长度很长时,查询大小与 KV 缓存大小相比可以忽略不计。

在我们的分析中,我们使用 NVIDIA Nsight Compute (NCU) 作为主要的性能分析器。我们通用的瓶颈消除方法是最小化停滞周期。我们对 INT4 GQA 应用了 10 项优化,其中三项是 NVIDIA A100/H100 GPU 特有的。这些优化是众所周知的 CUDA 优化技术,可以推广到许多应用中。

值得注意的是,我们选择优化 CUDA 实现而非 Flash-Decoding 实现(FD,基于 Triton)的原因在于,使用 CUDA,我们可以更好地控制底层指令的生成方式。我们应用的许多优化技术,例如直接操作张量核片段(优化 7-9),无法通过 Triton 实现,因为它不向开发者暴露底层细节。然而,这些优化可以集成到基于编译器的解决方案中,使这些优化可用于更广泛的算子,这确实是我们未来计划的一部分。

优化 1:展开 K 加载

问题分析

NCU 性能分析结果显示,在 K 加载期间,只有 2 个全局加载,随后在 dequantize_permuted_int4 处出现内存停滞。内存停滞是长的记分板停滞,这表明在等待全局内存访问。这表明 kernel 没有发出足够的内存加载:

以隐藏全局加载延迟。kernel 发出数据加载指令后,立即等待消费数据,从而暴露了全局加载延迟。停滞情况如图 5 所示。

Figure 5: K loading before unrolling

图 5 展开前的 K 加载(箭头指向的数字表示由全局内存等待引起的停滞周期)

解决方案

在基准实现中,我们使用 uint32_t 在一次加载中加载 8 个 INT4 K 值,并且在每次迭代中执行 2 次 uint32_t 加载,即 16 个 INT4 K 值。为了更好地隐藏全局加载延迟,我们在消费 dequantize_permuted_int4 中的 K 值之前,发出 8 次 uint32_t 加载,而非 2 次。这使得编译器能够展开加载并重新排序指令,从而更好地隐藏全局加载延迟。图 6 显示了展开后 K 加载的 NCU 性能分析结果。比较图 5 和图 6,我们通过展开 K 加载,有效减少了停滞周期。

Figure 6: K loading after unrolling

图 6 展开后的 K 加载(箭头指向的数字表示由全局内存等待引起的停滞周期)

结果

表 3 优化 1 对 INT4 GQA 的性能提升(按行量化)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 1 基准 优化 1
32 137 143 134 262 250 267 1.02 1.07
64 234 257 237 305 278 302 0.99 1.09
128 432 455 422 331 314 339 1.02 1.08
256 815 866 806 351 331 355 1.01 1.07
512 1581 1659 1550 362 345 369 1.02 1.07

优化 2:改进 P 类型转换(FP32->BF16)

问题分析

由于 softmax(bmm(Q, KT) / sqrt(D)) 的结果是 FP32(在图 3 中表示为 P),kernel 必须在将其馈送到下一个 bmm 计算之前将 P 从 FP32 转换为 BF16。kernel 通过将 FP32 数据从共享内存的一个位置复制到共享内存的另一个位置来执行 P 的 FP32 到 BF16 转换。这导致在访问共享内存期间出现停滞(如图 7 所示),这可能由 (1) 共享内存间接寻址;以及 (2) 共享内存 bank 冲突引起,因为每个线程访问一个 16 位元素(因此,两个线程可以同时访问同一个 memory bank)。

Figure 7: P type casting before Optimization 2

图 7 优化 2 之前的 P 类型转换(箭头指向的数字表示由共享内存等待引起的停滞周期)

解决方案

我们使用线程块中的所有线程进行就地类型转换。每个线程操作两个连续的元素,以避免在存储 BF16 时出现共享内存 bank 冲突。所有线程同时处理同一个头(h),以保证转换的正确性。就地转换步骤如下:

  1. 每个线程将同一个头的 2 个 FP32 token 元素从共享内存加载到寄存器中
  2. 调用 __syncthreads() 以确保每个线程完成数据读取
  3. 每个线程将其数据转换为 2 个 BF16 token 元素,然后将结果存储回同一个共享内存

我们对实现应用的一些优化:

  • 使用向量类型(特别是 nv_bfloat2
  • 展开数据加载/存储,即在调用 __syncthreads() 之前执行多次加载,并在调用 __syncthreads() 之后执行多次存储

应用此优化后,在 P 类型转换期间未观察到长停滞,如图 8 所示。

Figure 8: P type casting after Optimization 2

图 8 优化 2 之后的 P 类型转换(箭头指向的数字表示由共享内存等待引起的停滞周期)

潜在问题

由于我们通过使用寄存器作为中间存储来展开数据加载/存储,因此每个线程的寄存器数量增加,导致占用率降低。

结果

表 4 优化 2 对 INT4 GQA 的性能提升(按行量化)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 2 基准 优化 2
32 137 143 126 262 250 285 1.09 1.14
64 234 257 221 305 278 324 1.06 1.16
128 432 455 395 331 314 362 1.09 1.15
256 815 866 749 351 331 382 1.09 1.16
512 1581 1659 1435 362 345 399 1.10 1.16

优化 3:移除 max QKT 计算中的局部内存使用

问题分析

在 softmax 计算期间,kernel 必须计算每个头的 max QKT。它使用一个临时的“线程局部”存储来存储每个线程的 max QKT 结果(每个头一个 float 值)。根据编译器,线程局部存储可以分配在寄存器(片上)或局部内存(片外 == 全局内存)上。不幸的是,在基准中,线程局部存储位于局部内存中,这比寄存器慢得多(如图 9 所示)。我们怀疑这是因为编译器无法在编译时确定线程局部存储的索引(因为 kernel 中的头数量(H)是运行时变量)。像访问寄存器一样访问局部内存可能会损害 kernel 的性能。

Figure 9: Local memory access during max QKT computation

图 9 在 max QKT 计算期间的局部内存访问

解决方案

我们意识到,我们不需要每个线程使用 H(头数量)个 float 作为临时存储,因为每个线程可以只为一个头计算 max QKT,而非为所有头计算。因此,我们每个线程只需一个 float,这可以很容易地存储在寄存器中。为了在 warp 之间累加最大值结果,我们使用共享内存。此优化消除了在 max QKT 计算期间对局部内存的使用。

结果

表 5 优化 3 对 INT4 GQA 的性能提升(按行量化)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 3 基准 优化 3
32 137 143 119 262 250 300 1.14 1.20
64 234 257 206 305 278 348 1.14 1.25
128 432 455 368 331 314 389 1.17 1.24
256 815 866 696 351 331 411 1.17 1.24
512 1581 1659 1338 362 345 428 1.18 1.24

优化 4:移除行求和中的局部内存使用

问题分析

类似于 优化 3,在 softmax 计算的行求和期间也观察到局部内存使用问题。由于局部内存是片外内存,像访问寄存器一样访问它可能会损害 kernel 的性能。

解决方案:

我们对行求和计算应用了与 max QKT 计算相同的解决方案。即让每个线程只计算一个头的行求和,这每个线程只需一个 float。这消除了对局部内存的需求。

结果

表 6 优化 4 对 INT4 GQA 的性能提升(按行量化)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 4 基准 优化 4
32 137 143 118 262 250 302 1.15 1.21
64 234 257 204 305 278 351 1.15 1.26
128 432 455 364 331 314 393 1.19 1.25
256 815 866 688 351 331 416 1.18 1.26
512 1581 1659 1328 362 345 431 1.19 1.25

优化 5:为 V 加载添加预取

问题分析

加载 V 时也观察到与加载 K 相同的问题。也就是说,kernel 发出数据加载指令后,立即等待消费数据,从而暴露了全局加载延迟。然而,当使用上述展开技术时,编译器会将临时缓冲区分配到局部内存而不是寄存器中,导致性能大幅下降。

解决方案

我们为 V 加载采用数据预取技术。在消费当前迭代值后,我们立即加载下一迭代的 V 值。这使得数据加载与 PV 计算重叠,从而提高了 kernel 性能。

结果

表 7 优化 5 对 INT4 GQA 的性能提升(按行量化)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 5 基准 优化 5
32 137 143 109 262 250 327 1.25 1.31
64 234 257 194 305 278 370 1.21 1.33
128 432 455 345 331 314 414 1.25 1.32
256 815 866 649 351 331 441 1.26 1.33
512 1581 1659 1244 362 345 460 1.27 1.33

优化 6:添加分组式 INT4(组数 = 4)并使用向量加载

问题分析

在此优化之前,CU 仅支持按行 INT4 量化。也就是说,每行中的每列共享相同的缩放因子。每行的缩放因子存储在每行的前 4 字节中,如图 10 所示。在 kernel 中,每个线程一次只加载一行。由于每行包含 68 字节(4 字节用于缩放因子,64 字节用于数据),因此无法保证每行与任何向量类型的大小对齐。因此,无法使用向量加载来加载 KV 缓存。

Figure 10: The layout of each row of INT4 KV cache with row-wise quantization

图 10 按行量化 INT4 KV 缓存每行的布局

解决方案

我们实现了对分组式 INT4 量化(组数 = 4)的支持。在这种情况下,KV 缓存张量每行中的列被分成 4 个相等的组。同一组内的列共享相同的缩放因子进行量化/反量化。INT4 KV 缓存的数据布局如图 11 所示。所有组的缩放因子被序列化并存储在每行的开头。INT4 数据也被序列化并布局在缩放因子旁边。

由于每行中的字节数现在变为 80 字节,我们可以使用向量类型(即在我们的情况下为 uint2)来加载数据。(我们使用 uint4,因为每个线程一次只加载 16 个 INT4 值,这是由于张量核片段大小的限制。)向量加载通常优于标量加载,因为它不会导致额外的字节加载。

Figure 11: The layout of each row of INT4 KV cache with row-wise quantization

图 11 分组式 INT4 量化(组数 = 4)INT4 KV 缓存每行的布局

结果

表 8 优化 6 对 INT4 GQA 的性能提升(按行量化)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 6 基准 优化 6
32 137 143 111 262 250 322 1.23 1.29
64 234 257 192 305 278 372 1.22 1.34
128 432 455 346 331 314 414 1.25 1.32
256 815 866 642 351 331 446 1.27 1.35
512 1581 1659 1244 362 345 460 1.27 1.33

表 9 优化 6 对 INT4 GQA 的性能提升(分组式量化,组数 = 4)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CUDA_WMMA FD CUDA_WMMA 相对于 FD
优化 6 优化 6
32 129 116 325 364 1.31
64 219 195 385 431 1.36
128 392 347 429 484 1.39
256 719 638 468 527 1.41
512 1375 1225 489 550 1.43

优化 7:直接从 WMMA 片段计算 max QKT(A100/H100 特有)

问题分析

我们观察到在 max QKT 计算期间,由于访问共享内存而出现大量停滞(表现为大的短记分板停滞),如图 12 所示。

Figure 12: Stalls due to shared memory access during max QKT computation

图 12 在 max QKT 计算期间因共享内存访问引起的停滞(箭头指向的数字表示由共享内存等待引起的停滞周期)

解决方案

我们通过直接从 WMMA 片段(即张量核片段)计算 max QKT,从而绕过共享内存。WMMA 片段的布局取决于 GPU 架构。在此优化中,我们仅为 NVIDIA A100/H100 GPU 启用了此优化。其他 GPU 在计算 max QKT 时仍将使用共享内存。通过绕过共享内存,我们有效消除了由共享内存访问引起的停滞。用于存储 QKT 结果的 C 片段的张量核布局如图 13 所示。

Figure 13: C fragment (QKT storage) tensor core layout on A100/H100

图 13 A100/H100 上 C 片段(QKT 存储)的张量核布局

表 10 优化 7 对 INT4 GQA 的性能提升(按行量化)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 7 基准 优化 7
32 137 143 107 262 250 333 1.27 1.33
64 234 257 183 305 278 391 1.28 1.40
128 432 455 333 331 314 430 1.30 1.37
256 815 866 620 351 331 461 1.31 1.40
512 1581 1659 1206 362 345 475 1.31 1.38

表 11 优化 7 对 INT4 GQA 的性能提升(分组式量化,组数 = 4)

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CUDA_WMMA FD CUDA_WMMA 相对于 FD 相对于 CUDA_WMMA 优化 6
优化 6 优化 7 优化 6 优化 7
32 129 116 111 325 364 380 1.17 1.04
64 219 195 187 385 431 449 1.17 1.04
128 392 347 333 429 484 506 1.18 1.04
256 719 638 615 468 527 547 1.17 1.04
512 1375 1225 1184 489 550 569 1.16 1.03

优化 8:将 FP32->BF16 结果直接写入 P 片段(A100/H100 特有)

问题分析

在为 P 片段进行 FP32-BF16 转换期间,kernel 从共享内存加载 FP32 数据,进行转换,然后将 BF16 数据存储回共享内存。此外,此转换需要多次线程块同步(__syncthreads())。

解决方案

由于 kernel 的数据划分设计,每个 warp 只对 P 片段执行一次遍历。因此,我们无需将转换结果写回共享内存以供将来使用。为了避免将 BF16 数据写入共享内存以及线程块同步,我们让每个 warp 从共享内存加载 P WMMA 片段的 FP32 数据,进行转换,然后将 BF16 数据直接写入 P 片段。

注意,此优化仅应用于 NVIDIA A100 和 H100 GPU,因为 WMMA 片段布局依赖于架构。对于非 A100/H100 GPU,kernel 将回退到原始路径。

P 片段的张量核布局如图 14 所示。注意,此布局仅适用于 NVIDIA A100/H100 GPU。

Figure 14: P fragment tensor core layout on A100/H100

图 14 A100/H100 上 P 片段的张量核布局

表 12 INT4 GQA (行式量化) 的优化 8 性能

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 8 (Opt 8) 基准 优化 8 (Opt 8)
32 137 143 101 262 250 353 1.35 1.41
64 234 257 174 305 278 410 1.34 1.47
128 432 455 317 331 314 451 1.36 1.43
256 815 866 590 351 331 485 1.38 1.47
512 1581 1659 1143 362 345 501 1.38 1.45

表 13 INT4 GQA (组式量化,组数 = 4) 的优化 8 性能

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CUDA_WMMA FD CUDA_WMMA 相对于 FD 相对于 CUDA_WMMA 优化 6
优化 6 优化 8 (Opt 8) 优化 6 优化 8 (Opt 8)
32 129 116 106 325 364 396 1.22 1.09
64 219 195 180 385 431 467 1.21 1.08
128 392 347 319 429 484 528 1.23 1.09
256 719 638 596 468 527 565 1.21 1.07
512 1375 1225 1138 489 550 591 1.21 1.08

优化 9:混洗 (Swizzle) P 共享内存布局 (A100/H100 特定)

问题分析

我们在加载 P 时观察到大量的共享内存 bank 冲突。bank 冲突的数量取决于内存访问步长。例如,对于 split-Ks = 32 和最大序列长度 = 8192 的情况,我们观察到在并行访问时,32 个 bank 中只有 4 个被访问(内存访问步长 = 256)。从图 14 中可以看出,当所有线程访问元素 0 时,具有相同 threadIdx.x % 4 的线程访问同一个 bank。

Figure 15: P fragment in shared memory before swizzling

图 15 混洗前的共享内存中的 P 片段

解决方案

我们通过混洗 (swizzle) 共享内存中 P 加载/存储的布局,以避免 bank 冲突。换句话说,我们使用混洗后的布局存储 QKT 的结果(C 片段)并加载它们(P 片段)。此外,我们不再使用依赖于每个线程块的 token 数量的原始内存访问步长,而是使用片段的列大小作为步长,该步长是恒定的。因此,P 片段的加载和存储始终是连续的。

C 和 P 片段的新布局如图 16 所示。通过新布局,可以保证 16 个 bank 被并行访问,如图 17 所示。

Figure 16: The swizzled layouts of C and P fragments

图 16 C 和 P 片段的混洗布局

Figure 17: P fragment in shared memory after swizzling

图 17 混洗后的共享内存中的 P 片段

表 14 INT4 GQA (行式量化) 的优化 9 性能

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 9 (Opt 9) 基准 优化 9 (Opt 9)
32 137 143 98 262 250 365 1.39 1.46
64 234 257 167 305 278 429 1.41 1.54
128 432 455 299 331 314 479 1.45 1.52
256 815 866 549 351 331 521 1.48 1.58
512 1581 1659 1060 362 345 540 1.49 1.56

表 15 INT4 GQA (组式量化,组数 = 4) 的优化 9 性能

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CUDA_WMMA FD CUDA_WMMA 相对于 FD 相对于 CUDA_WMMA 优化 6
优化 6 优化 9 (Opt 9) 优化 6 优化 9 (Opt 9)
32 129 116 105 325 364 400 1.23 1.10
64 219 195 174 385 431 484 1.26 1.12
128 392 347 302 429 484 558 1.30 1.15
256 719 638 560 468 527 601 1.28 1.14
512 1375 1225 1065 489 550 632 1.29 1.15

优化 10:为 INT4 反量化填充共享内存

问题分析

一旦内核从全局内存读取 INT4 的 KV 缓存,它会执行反量化并将结果 (BF16) 存储在共享内存中。然后,BF16 数据从共享内存加载到 WMMA 片段(通过 WMMA 接口)。我们观察到 KV 访问都存在大量 bank 冲突。例如,对于 K 存储,32 个 bank 中只有 4 个被并行访问。对于 K 加载,16 个 bank 被并行访问。对于 V 存储和加载也发生同样的情况。参见解决方案部分的图。

解决方案

我们对共享内存进行填充以减少 bank 冲突。具体来说,我们将每一行填充 2。也就是说,K 的行步长变为 F_K + 2,V 的行步长变为 F_N + 2(F_KF_N 分别是 KV WMMA 片段的固定宽度)。通过这种优化,我们能够将 bank 冲突减少 1.8 倍,如图 18 所示。

Figure 18: Bank conflicts before and after Optimization 10

图 18 优化 10 前后的 Bank 冲突

优化 10 后,对于 K 存储,32 个 bank 被并行访问(如图 19 所示),而对于 K 加载,29 个 bank 被并行访问(如图 20 所示)。

Figure 19: K fragment store shared memory layout without and with padding

图 19 K 片段存储共享内存布局,带填充和不带填充

Figure 20: K fragment load shared memory layout without and with padding

图 20 K 片段加载共享内存布局,带填充和不带填充

表 16 INT4 GQA (行式量化) 的优化 10 性能

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CU FD CU 相对于 FD 相对于 CU 基准
基准 优化 10 (Opt 10) 基准 优化 10 (Opt 10)
32 137 143 94 262 250 380 1.45 1.52
64 234 257 151 305 278 475 1.55 1.71
128 432 455 266 331 314 538 1.63 1.71
256 815 866 489 351 331 586 1.67 1.77
512 1581 1659 930 362 345 616 1.70 1.79

表 17 INT4 GQA (组式量化,组数 = 4) 的优化 10 性能

批量大小 时间 (微秒) 带宽 (GB/s) 加速比
FD CUDA_WMMA FD CUDA_WMMA 相对于 FD 相对于 CUDA_WMMA 优化 6
优化 6 优化 10 (Opt 10) 优化 6 优化 10 (Opt 10)
32 129 116 99 325 364 425 1.31 1.17
64 219 195 161 385 431 523 1.36 1.21
128 392 347 282 429 484 598 1.39 1.23
256 719 638 509 468 527 662 1.41 1.25
512 1375 1225 965 489 550 698 1.43 1.27

性能评估

微基准测试结果

我们还使用我们优化的内核评估了 BF16 GQA 的性能(如表 19 所示)。对于 BF16,CU 的性能通常仍低于 FD 和 FA。这是预期结果,因为我们的优化主要针对 INT4。

虽然 INT4 GQA 仍然不如 BF16 GQA 高效(参见达到的带宽),但重要的是要注意,将 FD BF16 GQA 性能与 CU INT4 GQA 性能进行比较时,我们可以看到 INT4 的延迟小于 BF16

表 19 CU 优化后 BF16 GQA 和 INT GQA 的性能

在 A100 上

时间 (微秒) BF16 GQA INT4 GQA
批量大小 FD FA 优化前 CU 优化后 CU FD FA 优化前 CU 优化后 CU
32 139 133 183 163 137 - 143 94
64 245 229 335 276 234 - 257 151
128 433 555 596 517 432 - 455 266
256 826 977 1127 999 815 - 866 489
512 1607 1670 2194 1879 1581 - 1659 930
有效带宽 (GB/s) BF16 GQA INT4 GQA
批量大小 FD FA 优化前 CU 优化后 CU FD FA 优化前 CU 优化后 CU
32 965 1012 736 824 262 - 250 380
64 1097 1175 802 972 305 - 278 475
128 1240 968 901 1039 331 - 314 538
256 1301 1100 954 1075 351 - 331 586
512 1338 1287 980 1144 362 - 345 616

在 H100 上

时间 (微秒) BF16 GQA INT4 GQA
批量大小 FD FA 优化前 CU 优化后 CU FD FA 优化前 CU 优化后 CU
32 91 90 114 100 70 - 96 64
64 148 146 200 183 113 - 162 101
128 271 298 361 308 205 - 294 170
256 515 499 658 556 389 - 558 306
512 1000 1011 1260 1066 756 - 1066 575
有效带宽 (GB/s) BF16 GQA INT4 GQA
批量大小 FD FA 优化前 CU 优化后 CU FD FA 优化前 CU 优化后 CU
32 1481 1496 1178 1341 511 - 371 560
64 1815 1840 1345 1470 631 - 443 710
128 1982 1802 1487 1743 699 - 487 844
256 2087 2156 1634 1934 736 - 513 935
512 2150 2127 1706 2015 757 - 537 996

端到端 (E2E) 结果

我们在 8 个 H100 GPU 上评估了 Llama 2 70B 模型中我们优化的 INT4 GQA 内核的端到端性能。我们运行了完整的模型,但仅报告了解码延迟。我们使用 FP8 FFN (前馈网络) 来强调解码阶段的注意力性能。我们将批量大小从 1 变化到 256,上下文长度从 2,048 (2K) 变化到 16,384 (16K)。端到端性能结果如下图所示。

Figure 21: Meta Llama 2 decode latency (ms) comparison

图 21 Meta Llama 2 解码延迟 (毫秒) 比较 (BF16 GQA 在大批量配置下内存不足)

代码

如果您有兴趣,请查看我们的代码 此处。如果您有任何问题,请随时在 GitHub 上开启 issue,我们将很乐意提供帮助。欢迎您的贡献!