摘要
Hopper (H100) GPU 架构被称为“首个真正异步的 GPU”,它包含一个新的、完全异步的硬件复制引擎,用于在全局内存和共享内存之间进行批量数据移动,称为 Tensor Memory Accelerator (TMA)。虽然 CUTLASS 通过其异步流水线范例提供内置的 TMA 支持,但 Triton 则通过一个实验性 API 暴露 TMA 支持。
在本文中,我们深入探究了 TMA 的工作原理细节,以便开发者理解这个新的异步复制引擎。我们还通过在 Triton 中构建一个启用 TMA 的 FP8 GEMM 内核,展示了利用 TMA 对 H100 内核的重要性,该内核对于中小型问题规模,与 cuBLAS FP16 相比,性能提升了 1.4-2.2 倍。最后,我们展示了 Triton 和 CUTLASS 之间关键的实现差异,这些差异可能解释了 Triton 中使用 TMA 出现性能下降的报告。我们开源了我们的实现,以确保可复现性并供查阅:https://github.com/pytorch-labs/applied-ai/tree/main/kernels
图 1. 各种 Triton 和 cuBLAS FP8 和 FP16 内核的 TFLOPs 为单位的吞吐量,M=M, N=4096, K=4096。红线代表 Triton TMA,展示了利用 TMA 的优势。
TMA 背景
TMA 是 H100 的一个硬件新增功能,它允许应用程序在 GPU 全局内存和共享内存之间异步且双向地传输一维到五维 (1D-5D) 的张量。此外,TMA 不仅可以将相同数据传输到调用 SM 的共享内存,还可以传输到同一线程块集群中其他 SM 的共享内存。这被称为“多播”。
TMA 非常轻量级,只需要一个线程即可启动 TMA 传输。通过直接将数据从 GMEM(全局内存)移动到 SMEM(共享内存),这避免了早期 GPU 要求使用寄存器在不同内存空间之间移动数据的情况。
图 2. A100 式数据移动与支持 TMA 的 H100 对比。TMA 硬件消除了大量线程和寄存器参与批量数据传输的需求。(图片来源:Nvidia)
单个线程可以发出大型数据移动指令,允许给定线程块中的大部分在数据传输期间继续执行其他指令。结合异步流水线,这使得内存传输易于隐藏,并确保给定线程块集群中的大部分可以专注于计算任务。
这种轻量级的数据移动调用支持创建 warp 组专用内核,其中 warp 组扮演不同角色,即生产者和消费者。生产者选举一个主导线程发出 TMA 请求,这些请求随后通过到达屏障与消费者 (MMA) warp 组异步协调。消费者随后使用 warp 组 MMA 处理数据,并在完成从 SMEM 缓冲区读取后向生产者发出信号,然后循环重复。
此外,在线程块集群内部,生产者由于只负责发出 TMA 调用,可以降低其最大寄存器需求,并有效地将额外寄存器转移给 MMA 消费者,这有助于减轻消费者的寄存器压力。
此外,TMA 处理请求数据应放置的共享内存目标地址的计算。这就是为什么调用线程(生产者)可以如此轻量级。
为了确保最大读取访问速度,TMA 可以根据交织 (swizzling) 指令布置到达的数据,以确保消费者能尽快读取到达的数据,因为交织模式有助于避免共享内存体冲突。
最后,对于出站的 TMA 指令,即从 SMEM 移动数据到 GMEM 的指令,TMA 还可以包含归约操作(加/最小/最大)和位操作(与/或)。
TMA 在 Triton 中的使用
Hopper 之前的加载
offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = tl.arange(0, block_k)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k[None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k[:, None]*stride_bk + offs_bn[None, :]*stride_bn)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
图 3. Triton 中从全局内存到共享内存的传统风格批量加载
在上面展示 Hopper 之前加载的 Triton 示例中,我们看到每个线程块如何通过计算其相关程序 ID (pid_m, pid_n, k) 的全局偏移量 (a_ptrs, b_ptrs),然后发出请求将内存块移动到共享内存中以加载张量 a 和 b 的数据。
现在我们来看看如何在 Triton 中使用 TMA 执行加载。
TMA 指令需要一个特殊的称为 tensor map 的数据结构,这与上面我们直接传递全局内存指针不同。为了构建 tensor map,我们首先在 CPU 上创建一个 TMA descriptor。该 descriptor 通过使用 cuTensorMapEncode API 处理 tensor map 的创建。tensor map 包含元数据,例如张量在全局内存和共享内存中的布局,并作为存储在全局内存中的多维张量结构的压缩表示。
图 4. 通过复制描述符生成 TMA 地址 (图片来源:Nvidia)
TMA descriptor 包含张量的关键属性
- 基指针
- 形状和块大小
- 数据类型
TMA descriptor 在内核执行前在主机上创建,然后通过将 descriptor 传递给 torch 张量移动到设备。因此,在 Triton 中,GEMM 内核接收指向 tensor map 的全局指针。
Triton 主机端代码
desc_a = np.empty(TMA_SIZE, dtype=np.int8)
desc_b = np.empty(TMA_SIZE, dtype=np.int8)
desc_c = np.empty(TMA_SIZE, dtype=np.int8)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
desc_a = torch.tensor(desc_a, device='cuda')
desc_b = torch.tensor(desc_b, device='cuda')
desc_c = torch.tensor(desc_c, device='cuda')
这是用于在内核调用函数中设置 descriptor 的代码。
Triton 设备端代码
偏移量/指针算术
offs_am = pid_m * block_m
offs_bn = pid_n * block_n
offs_k = 0
加载
a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)
存储
tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
我们不再需要为内核中的加载和存储函数计算指针数组。相反,我们传递一个单一的 descriptor 指针、偏移量、块大小和输入数据类型。这简化了地址计算并降低了寄存器压力,因为我们不再需要在软件中执行复杂的指针算术,也不再需要分配 CUDA 核心用于地址计算。
TMA 性能分析
下面,我们将讨论针对 Hopper 上不同加载机制的 PTX 指令。
加载 Tile 的 PTX (cp.async) - 不使用 TMA 的 H100
add.s32 %r27, %r100, %r8;
add.s32 %r29, %r100, %r9;
selp.b32 %r30, %r102, 0, %p18;
@%p1 cp.async.cg.shared.global [ %r27 + 0 ], [ %rd20 + 0 ], 0x10, %r30;
@%p1 cp.async.cg.shared.global [ %r29 + 0 ], [ %rd21 + 0 ], 0x10, %r30;
cp.async.commit_group ;
这里,我们观察负责全局内存复制的旧的 cp.async 指令。从下面的跟踪记录中我们可以看到,这两种加载都绕过了 L1 缓存。新的 TMA 加载的一个主要区别是,在 Tensor Core 准备好消费来自 A 和 B 的 Tile 之前,我们需要执行 ldmatrix 指令,该指令操作寄存器文件中的数据。在 Hopper 上,数据现在可以直接从共享内存中重复使用。
图 5. 显示 GMEM 吞吐量的 H100 内存图 = 910.22 GB/s (使用 Triton GEMM,不使用 TMA),对于 M=128, N=4096, K=4096
通过利用我们上面提到的 Triton API 更改来使用 TMA,我们可以研究 Triton 为使用 TMA 的单个 2D Tile 加载生成的 PTX。
加载 Tile 的 PTX (cp.async.bulk.tensor) - 使用 TMA 的 H100
bar.sync 0;
shr.u32 %r5, %r4, 5;
shfl.sync.idx.b32 %r66, %r5, 0, 31, -1;
elect.sync _|%p7, 0xffffffff;
add.s32 %r24, %r65, %r67;
shl.b32 %r25, %r66, 7;
@%p8
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r24], [%rd26, {%r25,%r152}], [%r19];
cp.async.bulk.tensor.2d.shared TMA 指令被传递了共享内存中的目标地址、指向 tensor map 的指针、tensor map 坐标以及指向 mbarrier 对象的指针。
图 6. 显示 GMEM 吞吐量的 H100 内存图 = 1.45 TB/s (使用 Triton GEMM,使用 TMA),对于 M=128, N=4096, K=4096
为了获得最佳性能,我们广泛地调整了 TMA GEMM 内核。在其他参数如 Tile 大小、warp 数量和流水线阶段数量之外,当我们将 TMA_SIZE (descriptor size) 从 128 增加到 512 时,观察到内存吞吐量增加了最多。从上面的 NCU 分析结果中,我们可以看到最终调优的内核将全局内存传输吞吐量从 910 GB/s 提高到 1.45 TB/s,相较于不使用 TMA 的 Triton GEMM 内核,GMEM 吞吐量提高了 59%。
CUTLASS 与 Triton FP8 GEMM 和 TMA 实现对比 - 内核架构
图 7. Triton 与 CUTLASS Ping-Pong FP8 GEMM TFLOPs 对比,M=M, N=4096, K=4096
上图显示了 CUTLASS 的 Ping-Pong GEMM 内核 与 Triton 的性能对比。Ping-Pong 内核利用 TMA 的方式与 Triton 不同。它充分利用了其所有的软硬件能力,而 Triton 目前还没有。具体来说,CUTLASS 支持以下 TMA 特性,这些特性有助于解释纯 GEMM 性能中的差距:
-
TMA 多播
- 支持数据从 GMEM 复制到多个 SM
-
Warp 专用化
- 允许线程块内的 warp 组扮演不同角色
-
Tensor Map (TMA Descriptor) 预取
- 支持从 GMEM 预取 Tensor Map 对象,这使得 TMA 加载可以流水线化
为了更直观地展示性能数据,下面我们展示一张“加速”图,突出显示百分比表示的延迟差异
图 8: CUTLASS Ping-Pong 相较于使用 TMA 的 Triton FP8 的百分比加速。
这种加速纯粹是内核吞吐量,不包括我们将在下文讨论的端到端(E2E)启动开销。
TMA Descriptor 的移动 - Triton 和 CUTLASS 之间的关键差异及其对端到端(E2E)性能的影响
如前所述,2D+ 维 TMA descriptor 的创建发生在主机上,然后被传输到设备。然而,这个传输过程差异很大,取决于实现方式。
这里我们展示了 Triton 传输 TMA descriptor 的方式与 CUTLASS 相比存在的差异。
回想一下,TMA 传输需要一个特殊的数据结构,即 tensor map,它通过 cuTensorMap API 在 CPU 上创建,对于 FP8 GEMM 内核,这意味着创建三个 descriptor,分别对应 A、B 和 C。我们将在下面看到,对于 Triton 和 CUTLASS 内核,都调用了相同的 CPU 过程。
图 7. 对 cuTensorMapEncodeTiled 的调用(Triton 和 CUTLASS 都使用此路径)
然而,对于 Triton,每个 descriptor 都通过其独立的复制内核进行传输,这增加了大量的开销,并成为在端到端推理场景中使用此内核的障碍。
图 8. 在内核执行之前启动了三个 H2D 复制内核,分别用于 A、B 和 C
在 CUTLASS 实现中没有观察到这些复制,因为 TMA descriptor 传递给内核的方式不同。从下面的 PTX 代码中我们可以看到,使用 Cutlass 时,tensor maps 是按值传递给内核的。
.entry _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE(
.param .align 64 .b8 _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0[1024]
mov.b64 %rd110, _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_10bfloat16_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEES8_NS7_ILi256EEEEEENS6_IJNS7_ILi1EEESB_SB_EEENS_4gemm24KernelTmaWarpSpecializedENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0;
add.s64 %rd70, %rd110, 704;
cvta.param.u64 %rd69, %rd70;
cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd69, {%r284, %r283}], [%r1880];
图 9. 显示按值传递的 CUTLASS 内核 PTX
通过直接传递 TMA Descriptor 而不是传递全局内存指针,CUTLASS 内核避免了三个额外的 H2D 复制内核,并且这些复制被包含在 GEMM 的单个设备内核启动中。
由于 descriptor 移动到设备的方式不同,包括准备供 TMA 消费的张量所需时间的内核延迟差异巨大。对于 M=1-128, N=4096, K=4096,CUTLASS pingpong 内核的平均延迟为 10 微秒 (us),而 Triton TMA 内核的平均完成时间为 4 毫秒 (ms)。这慢了约 3330 倍,并且似乎直接与 Triton 为传输 TMA descriptor 而启动的 3 个独立内核有关。
CUDA 图(Cuda graphs)可能是减少此问题的一种方法,但考虑到 H2D 复制造成的开销,当前 Triton 实现在端到端测量时不具竞争力。重构 Triton 编译器管理 TMA descriptor 的方式可能会解决此差距。因此,我们在上述数据中侧重比较了实际的计算内核吞吐量,而非端到端(E2E)性能。
结果总结
图 10. Triton FP8 TMA GEMM TFLOPs 对比
M | Triton TMA | Triton Tutorial | Triton SplitK | cuBLAS FP8 | cuBLAS FP16 | CUTLASS Ping-Pong FP8 |
1 | 2.5 | 1 | 2.4 | 1.5 | 1.8 | 3.57 |
2 | 5.1 | 2.5 | 4.8 | 3.1 | 3.6 | 5.9 |
4 | 10.3 | 7.21 | 9.6 | 6.1 | 7.2 | 14.3 |
8 | 21.0 | 16.5 | 19.2 | 12.3 | 14.4 | 28.6 |
16 | 44.5 | 41.0 | 37.2 | 24.5 | 27.7 | 55.1 |
32 | 89.7 | 81.2 | 72.2 | 71.6 | 56.8 | 114.4 |
64 | 178.5 | 163.7 | 130.8 | 144.6 | 105.3 | 228.7 |
128 | 359.7 | 225.9 | 160.1 | 244.0 | 189.2 | 377.7 |
图 11. Triton FP8 TMA GEMM TFLOPs 对比表
上述图表总结了我们在单块 NVIDIA H100 上,通过利用 TMA 硬件单元,在 FP8 GEMM 方面取得的增益,相对于不使用 TMA 的 Triton 内核和高性能 CUDA (cuBLAS) 内核。需要注意的关键点是,该内核在扩展性(随批量大小)方面优于竞争对手。我们进行基准测试的问题规模代表了中小型批量大小的 LLM 推理中遇到的矩阵形状。因此,TMA GEMM 内核在中等 M 范围(M=32 到 M=128)的性能将至关重要,对于有兴趣在 FP8 LLM 部署用例中利用此内核的人来说,因为 FP8 压缩数据类型可以使更大的矩阵适合 GPU 内存。
总结我们的分析,Triton 和 CUTLASS 中的 TMA 实现方式在完整功能集支持(多播、预取等)以及 TMA Descriptor 如何传递给 GPU 内核方面存在差异。如果此 descriptor 以更接近 CUTLASS 内核的方式(按值传递)传递,可以避免额外的 H2D 复制,从而大幅提高端到端(E2E)性能。
未来工作
在未来研究中,我们计划改进这些结果,通过与社区合作,将 CUTLASS 的 TMA 加载架构整合到 Triton 中,以及研究用于 FP8 GEMM 的 Cooperative Kernel,这是一种修改后的 Ping-Pong 内核策略。
此外,一旦 Triton 中启用了线程块集群和 TMA 原子操作等功能,我们可能能够通过在 TMA GEMM 内核中利用 SplitK 策略获得进一步的加速,因为 Hopper 上的原子操作可以在分布式共享内存 (DSMEM) 中执行,而不是 L2 缓存。我们还注意到 NVIDIA Hopper GPU 与其他 AI 硬件加速器之间的相似性,例如谷歌的 TPU 和 IBM 的 AIU,它们都是数据流架构。在 Hopper 上,得益于 TMA(我们在本博客中对此进行了广泛讨论)和 DSMEM(我们计划在未来的文章中介绍),数据现在可以从 GMEM“流”向连接的 SM 网络。