作者:Adnan Hoque, Less Wright, Chih-Chieh Yang

摘要

Hopper (H100) GPU 架构被称为“首个真正异步的 GPU”,它包含一个新的、完全异步的硬件复制引擎,用于在全局内存和共享内存之间进行批量数据移动,称为 Tensor Memory Accelerator (TMA)。虽然 CUTLASS 通过其异步流水线范例提供内置的 TMA 支持,但 Triton 则通过一个实验性 API 暴露 TMA 支持。

在本文中,我们深入探究了 TMA 的工作原理细节,以便开发者理解这个新的异步复制引擎。我们还通过在 Triton 中构建一个启用 TMA 的 FP8 GEMM 内核,展示了利用 TMA 对 H100 内核的重要性,该内核对于中小型问题规模,与 cuBLAS FP16 相比,性能提升了 1.4-2.2 倍。最后,我们展示了 Triton 和 CUTLASS 之间关键的实现差异,这些差异可能解释了 Triton 中使用 TMA 出现性能下降的报告。我们开源了我们的实现,以确保可复现性并供查阅:https://github.com/pytorch-labs/applied-ai/tree/main/kernels

The throughput in TFLOPs of various Triton and cuBLAS FP8 and FP16 kernels, for M=M, N=4096, K=4096. The red line is the Triton TMA, which showcases the advantages of leveraging TMA.

图 1. 各种 Triton 和 cuBLAS FP8 和 FP16 内核的 TFLOPs 为单位的吞吐量,M=M, N=4096, K=4096。红线代表 Triton TMA,展示了利用 TMA 的优势。

TMA 背景

TMA 是 H100 的一个硬件新增功能,它允许应用程序在 GPU 全局内存和共享内存之间异步且双向地传输一维到五维 (1D-5D) 的张量。此外,TMA 不仅可以将相同数据传输到调用 SM 的共享内存,还可以传输到同一线程块集群中其他 SM 的共享内存。这被称为“多播”。

TMA 非常轻量级,只需要一个线程即可启动 TMA 传输。通过直接将数据从 GMEM(全局内存)移动到 SMEM(共享内存),这避免了早期 GPU 要求使用寄存器在不同内存空间之间移动数据的情况。

A100-style data movement vs H100 with TMA.  TMA hardware eliminates the need for a large amount of threads and registers participating in bulk data transfers.

图 2. A100 式数据移动与支持 TMA 的 H100 对比。TMA 硬件消除了大量线程和寄存器参与批量数据传输的需求。(图片来源:Nvidia)

单个线程可以发出大型数据移动指令,允许给定线程块中的大部分在数据传输期间继续执行其他指令。结合异步流水线,这使得内存传输易于隐藏,并确保给定线程块集群中的大部分可以专注于计算任务。

这种轻量级的数据移动调用支持创建 warp 组专用内核,其中 warp 组扮演不同角色,即生产者和消费者。生产者选举一个主导线程发出 TMA 请求,这些请求随后通过到达屏障与消费者 (MMA) warp 组异步协调。消费者随后使用 warp 组 MMA 处理数据,并在完成从 SMEM 缓冲区读取后向生产者发出信号,然后循环重复。

此外,在线程块集群内部,生产者由于只负责发出 TMA 调用,可以降低其最大寄存器需求,并有效地将额外寄存器转移给 MMA 消费者,这有助于减轻消费者的寄存器压力。

此外,TMA 处理请求数据应放置的共享内存目标地址的计算。这就是为什么调用线程(生产者)可以如此轻量级。

为了确保最大读取访问速度,TMA 可以根据交织 (swizzling) 指令布置到达的数据,以确保消费者能尽快读取到达的数据,因为交织模式有助于避免共享内存体冲突。

最后,对于出站的 TMA 指令,即从 SMEM 移动数据到 GMEM 的指令,TMA 还可以包含归约操作(加/最小/最大)和位操作(与/或)。

TMA 在 Triton 中的使用

Hopper 之前的加载

offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = tl.arange(0, block_k)

a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k[None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k[:, None]*stride_bk + offs_bn[None, :]*stride_bn)

a = tl.load(a_ptrs)
b = tl.load(b_ptrs)

图 3. Triton 中从全局内存到共享内存的传统风格批量加载

在上面展示 Hopper 之前加载的 Triton 示例中,我们看到每个线程块如何通过计算其相关程序 ID (pid_m, pid_n, k) 的全局偏移量 (a_ptrs, b_ptrs),然后发出请求将内存块移动到共享内存中以加载张量 a 和 b 的数据。

现在我们来看看如何在 Triton 中使用 TMA 执行加载。

TMA 指令需要一个特殊的称为 tensor map 的数据结构,这与上面我们直接传递全局内存指针不同。为了构建 tensor map,我们首先在 CPU 上创建一个 TMA descriptor。该 descriptor 通过使用 cuTensorMapEncode API 处理 tensor map 的创建。tensor map 包含元数据,例如张量在全局内存和共享内存中的布局,并作为存储在全局内存中的多维张量结构的压缩表示。

TMA address generation via a copy descriptor

图 4. 通过复制描述符生成 TMA 地址 (图片来源:Nvidia)

TMA descriptor 包含张量的关键属性

  1. 基指针
  2. 形状和块大小
  3. 数据类型

TMA descriptor 在内核执行前在主机上创建,然后通过将 descriptor 传递给 torch 张量移动到设备。因此,在 Triton 中,GEMM 内核接收指向 tensor map 的全局指针。

Triton 主机端代码

   desc_a = np.empty(TMA_SIZE, dtype=np.int8)
   desc_b = np.empty(TMA_SIZE, dtype=np.int8)
   desc_c = np.empty(TMA_SIZE, dtype=np.int8)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
  
   desc_a = torch.tensor(desc_a, device='cuda')
   desc_b = torch.tensor(desc_b, device='cuda')
   desc_c = torch.tensor(desc_c, device='cuda')

这是用于在内核调用函数中设置 descriptor 的代码。

Triton 设备端代码

偏移量/指针算术

   offs_am = pid_m * block_m
   offs_bn = pid_n * block_n
   offs_k = 0

加载

  a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
  b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)

存储

 tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

我们不再需要为内核中的加载和存储函数计算指针数组。相反,我们传递一个单一的 descriptor 指针、偏移量、块大小和输入数据类型。这简化了地址计算并降低了寄存器压力,因为我们不再需要在软件中执行复杂的指针算术,也不再需要分配 CUDA 核心用于地址计算。

TMA 性能分析

下面,我们将讨论针对 Hopper 上不同加载机制的 PTX 指令。

加载 Tile 的 PTX (cp.async) - 不使用 TMA 的 H100

add.s32 	%r27, %r100, %r8;
add.s32 	%r29, %r100, %r9;
selp.b32 	%r30, %r102, 0, %p18;


@%p1 cp.async.cg.shared.global [ %r27 + 0 ], [ %rd20 + 0 ], 0x10, %r30;
@%p1 cp.async.cg.shared.global [ %r29 + 0 ], [ %rd21 + 0 ], 0x10, %r30;


cp.async.commit_group ;

这里,我们观察负责全局内存复制的旧的 cp.async 指令。从下面的跟踪记录中我们可以看到,这两种加载都绕过了 L1 缓存。新的 TMA 加载的一个主要区别是,在 Tensor Core 准备好消费来自 A 和 B 的 Tile 之前,我们需要执行 ldmatrix 指令,该指令操作寄存器文件中的数据。在 Hopper 上,数据现在可以直接从共享内存中重复使用。

H100 Memory Chart showing GMEM Throughput = 910.22 GB/s

图 5. 显示 GMEM 吞吐量的 H100 内存图 = 910.22 GB/s (使用 Triton GEMM,不使用 TMA),对于 M=128, N=4096, K=4096

通过利用我们上面提到的 Triton API 更改来使用 TMA,我们可以研究 Triton 为使用 TMA 的单个 2D Tile 加载生成的 PTX。

加载 Tile 的 PTX (cp.async.bulk.tensor) - 使用 TMA 的 H100

bar.sync 	0;
shr.u32 	%r5, %r4, 5;
shfl.sync.idx.b32	%r66, %r5, 0, 31, -1;

elect.sync _|%p7, 0xffffffff;


add.s32 	%r24, %r65, %r67;
shl.b32 	%r25, %r66, 7;

@%p8
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r24], [%rd26, {%r25,%r152}], [%r19];

cp.async.bulk.tensor.2d.shared TMA 指令被传递了共享内存中的目标地址、指向 tensor map 的指针、tensor map 坐标以及指向 mbarrier 对象的指针。

H100 Memory Chart GMEM Throughput =1.45 TB/s

图 6. 显示 GMEM 吞吐量的 H100 内存图 = 1.45 TB/s (使用 Triton GEMM,使用 TMA),对于 M=128, N=4096, K=4096

为了获得最佳性能,我们广泛地调整了 TMA GEMM 内核。在其他参数如 Tile 大小、warp 数量和流水线阶段数量之外,当我们将 TMA_SIZE (descriptor size) 从 128 增加到 512 时,观察到内存吞吐量增加了最多。从上面的 NCU 分析结果中,我们可以看到最终调优的内核将全局内存传输吞吐量从 910 GB/s 提高到 1.45 TB/s,相较于不使用 TMA 的 Triton GEMM 内核,GMEM 吞吐量提高了 59%

CUTLASS 与 Triton FP8 GEMM 和 TMA 实现对比 - 内核架构

Triton vs CUTLASS Ping-Pong FP8 GEMM TFLOPs, M=M, N=4096, K=4096

图 7. Triton 与 CUTLASS Ping-Pong FP8 GEMM TFLOPs 对比,M=M, N=4096, K=4096

上图显示了 CUTLASS 的 Ping-Pong GEMM 内核 与 Triton 的性能对比。Ping-Pong 内核利用 TMA 的方式与 Triton 不同。它充分利用了其所有的软硬件能力,而 Triton 目前还没有。具体来说,CUTLASS 支持以下 TMA 特性,这些特性有助于解释纯 GEMM 性能中的差距:

  1. TMA 多播

    • 支持数据从 GMEM 复制到多个 SM
  2. Warp 专用化

    • 允许线程块内的 warp 组扮演不同角色
  3. Tensor Map (TMA Descriptor) 预取

    • 支持从 GMEM 预取 Tensor Map 对象,这使得 TMA 加载可以流水线化

为了更直观地展示性能数据,下面我们展示一张“加速”图,突出显示百分比表示的延迟差异

% Speedup of CUTLASS Ping-Pong vs Triton FP8 with TMA.

图 8: CUTLASS Ping-Pong 相较于使用 TMA 的 Triton FP8 的百分比加速。

这种加速纯粹是内核吞吐量,不包括我们将在下文讨论的端到端(E2E)启动开销。

TMA Descriptor 的移动 - Triton 和 CUTLASS 之间的关键差异及其对端到端(E2E)性能的影响

如前所述,2D+ 维 TMA descriptor 的创建发生在主机上,然后被传输到设备。然而,这个传输过程差异很大,取决于实现方式。

这里我们展示了 Triton 传输 TMA descriptor 的方式与 CUTLASS 相比存在的差异。

回想一下,TMA 传输需要一个特殊的数据结构,即 tensor map,它通过 cuTensorMap API 在 CPU 上创建,对于 FP8 GEMM 内核,这意味着创建三个 descriptor,分别对应 A、B 和 C。我们将在下面看到,对于 Triton 和 CUTLASS 内核,都调用了相同的 CPU 过程。

Calls to cuTensorMapEncodeTiled (Both Triton and CUTLASS use this path)

图 7. 对 cuTensorMapEncodeTiled 的调用(Triton 和 CUTLASS 都使用此路径)

然而,对于 Triton,每个 descriptor 都通过其独立的复制内核进行传输,这增加了大量的开销,并成为在端到端推理场景中使用此内核的障碍。

Three H2D Copy Kernels are launched before the kernel execution, for A, B and C

图 8. 在内核执行之前启动了三个 H2D 复制内核,分别用于 A、B 和 C

在 CUTLASS 实现中没有观察到这些复制,因为 TMA descriptor 传递给内核的方式不同。从下面的 PTX 代码中我们可以看到,使用 Cutlass 时,tensor maps 是按值传递给内核的。

.entry _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE(

.param .align 64 .b8 _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0[1024]


mov.b64 	%rd110, _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_10bfloat16_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEES8_NS7_ILi256EEEEEENS6_IJNS7_ILi1EEESB_SB_EEENS_4gemm24KernelTmaWarpSpecializedENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0;

add.s64 	%rd70, %rd110, 704;
cvta.param.u64 	%rd69, %rd70;

cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd69, {%r284, %r283}], [%r1880];

图 9. 显示按值传递的 CUTLASS 内核 PTX

通过直接传递 TMA Descriptor 而不是传递全局内存指针,CUTLASS 内核避免了三个额外的 H2D 复制内核,并且这些复制被包含在 GEMM 的单个设备内核启动中。

由于 descriptor 移动到设备的方式不同,包括准备供 TMA 消费的张量所需时间的内核延迟差异巨大。对于 M=1-128, N=4096, K=4096,CUTLASS pingpong 内核的平均延迟为 10 微秒 (us),而 Triton TMA 内核的平均完成时间为 4 毫秒 (ms)。这慢了约 3330 倍,并且似乎直接与 Triton 为传输 TMA descriptor 而启动的 3 个独立内核有关。

CUDA 图(Cuda graphs)可能是减少此问题的一种方法,但考虑到 H2D 复制造成的开销,当前 Triton 实现在端到端测量时不具竞争力。重构 Triton 编译器管理 TMA descriptor 的方式可能会解决此差距。因此,我们在上述数据中侧重比较了实际的计算内核吞吐量,而非端到端(E2E)性能。

结果总结

Triton FP8 TMA GEMM TFLOPs Comparison

图 10. Triton FP8 TMA GEMM TFLOPs 对比

M Triton TMA Triton Tutorial Triton SplitK cuBLAS FP8 cuBLAS FP16 CUTLASS Ping-Pong FP8
1 2.5 1 2.4 1.5 1.8 3.57
2 5.1 2.5 4.8 3.1 3.6 5.9
4 10.3 7.21 9.6 6.1 7.2 14.3
8 21.0 16.5 19.2 12.3 14.4 28.6
16 44.5 41.0 37.2 24.5 27.7 55.1
32 89.7 81.2 72.2 71.6 56.8 114.4
64 178.5 163.7 130.8 144.6 105.3 228.7
128 359.7 225.9 160.1 244.0 189.2 377.7

图 11. Triton FP8 TMA GEMM TFLOPs 对比表

上述图表总结了我们在单块 NVIDIA H100 上,通过利用 TMA 硬件单元,在 FP8 GEMM 方面取得的增益,相对于不使用 TMA 的 Triton 内核和高性能 CUDA (cuBLAS) 内核。需要注意的关键点是,该内核在扩展性(随批量大小)方面优于竞争对手。我们进行基准测试的问题规模代表了中小型批量大小的 LLM 推理中遇到的矩阵形状。因此,TMA GEMM 内核在中等 M 范围(M=32 到 M=128)的性能将至关重要,对于有兴趣在 FP8 LLM 部署用例中利用此内核的人来说,因为 FP8 压缩数据类型可以使更大的矩阵适合 GPU 内存。

总结我们的分析,Triton 和 CUTLASS 中的 TMA 实现方式在完整功能集支持(多播、预取等)以及 TMA Descriptor 如何传递给 GPU 内核方面存在差异。如果此 descriptor 以更接近 CUTLASS 内核的方式(按值传递)传递,可以避免额外的 H2D 复制,从而大幅提高端到端(E2E)性能。

未来工作

在未来研究中,我们计划改进这些结果,通过与社区合作,将 CUTLASS 的 TMA 加载架构整合到 Triton 中,以及研究用于 FP8 GEMM 的 Cooperative Kernel,这是一种修改后的 Ping-Pong 内核策略。

此外,一旦 Triton 中启用了线程块集群和 TMA 原子操作等功能,我们可能能够通过在 TMA GEMM 内核中利用 SplitK 策略获得进一步的加速,因为 Hopper 上的原子操作可以在分布式共享内存 (DSMEM) 中执行,而不是 L2 缓存。我们还注意到 NVIDIA Hopper GPU 与其他 AI 硬件加速器之间的相似性,例如谷歌的 TPU 和 IBM 的 AIU,它们都是数据流架构。在 Hopper 上,得益于 TMA(我们在本博客中对此进行了广泛讨论)和 DSMEM(我们计划在未来的文章中介绍),数据现在可以从 GMEM“流”向连接的 SM 网络。