作者:Adnan Hoque、Less Wright、Chih-Chieh Yang

摘要

Hopper (H100) GPU 架构号称“首个真正的异步 GPU”,它包含一个新的、完全异步的硬件复制引擎,用于全局内存和共享内存之间的大量数据移动,称为 Tensor Memory Accelerator (TMA)。虽然 CUTLASS 通过其异步流水线范式 内置 了对 TMA 的支持,但 Triton 通过 实验性 API 公开了 TMA 支持。

在这篇文章中,我们将深入探讨 TMA 的工作原理细节,以帮助开发者理解新的异步复制引擎。我们还展示了为 H100 内核利用 TMA 的重要性,方法是在 Triton 中构建一个支持 TMA 的 FP8 GEMM 内核,该内核在小到中等问题规模下,性能比 cuBLAS FP16 提高了 1.4-2.2 倍。最后,我们展示了 Triton 和 CUTLASS 之间可能导致关于 Triton 中 TMA 性能下降报告的关键实现差异。我们将我们的实现开源,以供重现和审查,网址为 https://github.com/pytorch-labs/applied-ai/tree/main/kernels

The throughput in TFLOPs of various Triton and cuBLAS FP8 and FP16 kernels, for M=M, N=4096, K=4096. The red line is the Triton TMA, which showcases the advantages of leveraging TMA.

图 1. 各种 Triton 和 cuBLAS FP8 和 FP16 内核在 M=M、N=4096、K=4096 时的吞吐量(单位:TFLOPs)。红线是 Triton TMA,它展示了利用 TMA 的优势。

TMA 背景

TMA 是 H100 硬件新增功能,允许应用程序在 GPU 全局内存和共享内存之间异步且双向地传输 1D-5D 张量。此外,TMA 还可以将相同的数据不仅传输到调用 SM 的共享内存,还可以传输到同一线程块集群的其他 SM 的共享内存。这被称为“多播”。

TMA 非常轻量级,因为只需要一个线程即可启动 TMA 传输。通过直接将数据从 GMEM(全局内存)移动到 SMEM(共享内存),避免了早期 GPU 在不同内存空间之间移动数据时需要使用寄存器的要求。

A100-style data movement vs H100 with TMA.  TMA hardware eliminates the need for a large amount of threads and registers participating in bulk data transfers.

图 2. A100 风格的数据移动与使用 TMA 的 H100。TMA 硬件消除了大量线程和寄存器参与批量数据传输的需求。(图片来源:Nvidia)

单个线程可以发出大型数据移动指令,允许给定线程块的大部分继续处理其他指令,同时数据正在传输中。结合异步流水线,这使得内存传输可以轻松隐藏,并确保任何给定线程块集群的大部分可以专注于计算任务。

这种轻量级的数据移动调用使得可以创建 warp-group 专用内核,其中 warp-group 承担不同的角色,即生产者和消费者。生产者选择一个领导线程来触发 TMA 请求,然后通过到达屏障与消费者 (MMA) warp-group 异步协调。然后,消费者使用 warp-group MMA 处理数据,并在完成从 SMEM 缓冲区读取数据后向生产者发出信号,并重复该循环。

此外,在线程块集群中,生产者可以降低其最大寄存器需求,因为它们仅发出 TMA 调用,并有效地将额外的寄存器转移给 MMA 消费者,这有助于缓解消费者的寄存器压力。

此外,TMA 还处理共享内存目标地址的计算,数据请求应放置在该地址。这就是为什么调用线程(生产者)可以如此轻量级的原因。

为了确保最大的读取访问速度,TMA 可以根据 swizzling 指令布局到达的数据,以确保消费者可以尽可能快地读取到达的数据,因为 swizzling 模式有助于避免共享内存库冲突。

最后,对于传出或从 SMEM 移动到 GMEM 的 TMA 指令,TMA 还可以包括归约运算(add/min/max)和按位运算(and/or)。

Triton 中 TMA 的使用

Hopper 之前的加载

offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = tl.arange(0, block_k)

a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k[None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k[:, None]*stride_bk + offs_bn[None, :]*stride_bn)

a = tl.load(a_ptrs)
b = tl.load(b_ptrs)

图 3. Triton 中从全局内存到共享内存的传统风格批量加载

在上面的 Triton 示例中,显示了 Hopper 之前的加载,我们看到张量 a 和 b 的数据是如何通过每个线程块计算来自其相关 program_id (pid_m, pid_n, k) 的全局偏移量 (a_ptrs, b_ptrs),然后发出请求将内存块移动到 a 和 b 的共享内存中。

现在让我们研究如何在 Triton 中使用 TMA 执行加载。

TMA 指令需要一个特殊的数据结构,称为张量映射,与上面我们直接将指针传递给全局内存的情况相反。为了构建张量映射,我们首先在 CPU 上创建一个 TMA 描述符。描述符通过使用 cuTensorMapEncode API 来处理张量映射的创建。张量映射保存元数据,例如张量的全局内存和共享内存布局,并充当存储在全局内存中的多维张量结构的压缩表示。

TMA address generation via a copy descriptor

图 4. 通过复制描述符的 TMA 地址生成(图片来源:Nvidia)

TMA 描述符保存张量的关键属性

  1. 基指针
  2. 形状和块大小
  3. 数据类型

TMA 描述符在内核之前在主机上创建,然后通过将描述符传递给 torch 张量来移动到设备。因此,在 Triton 中,GEMM 内核接收指向张量映射的全局指针。

Triton 主机代码

   desc_a = np.empty(TMA_SIZE, dtype=np.int8)
   desc_b = np.empty(TMA_SIZE, dtype=np.int8)
   desc_c = np.empty(TMA_SIZE, dtype=np.int8)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
  
   desc_a = torch.tensor(desc_a, device='cuda')
   desc_b = torch.tensor(desc_b, device='cuda')
   desc_c = torch.tensor(desc_c, device='cuda')

这是用于在内核调用函数中设置描述符的代码。

Triton 设备代码

偏移量/指针算术

   offs_am = pid_m * block_m
   offs_bn = pid_n * block_n
   offs_k = 0

加载

  a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
  b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)

存储

 tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

我们不再需要在内核中为加载和存储函数计算指针数组。相反,我们传递单个描述符指针、偏移量、块大小和输入数据类型。这简化了地址计算并减少了寄存器压力,因为我们不再需要在软件中进行复杂的指针运算,也不需要为地址计算分配 CUDA 核心。

TMA 性能分析

下面,我们讨论 Hopper 上不同加载机制的 PTX 指令。

加载 Tile 的 PTX (cp.async) - H100 无 TMA

add.s32 	%r27, %r100, %r8;
add.s32 	%r29, %r100, %r9;
selp.b32 	%r30, %r102, 0, %p18;


@%p1 cp.async.cg.shared.global [ %r27 + 0 ], [ %rd20 + 0 ], 0x10, %r30;
@%p1 cp.async.cg.shared.global [ %r29 + 0 ], [ %rd21 + 0 ], 0x10, %r30;


cp.async.commit_group ;

在这里,我们观察到负责全局内存复制的旧 cp.async 指令。从下面的跟踪中,我们可以看到两个加载都绕过了 L1 缓存。较新的 TMA 加载的主要区别在于,在来自 A 和 B 的 tile 准备好被 Tensor Core 消耗之前,我们需要执行一个 ldmatrix 指令,该指令作用于寄存器文件中包含的数据。在 Hopper 上,数据现在可以直接从共享内存中重用。

H100 Memory Chart showing GMEM Throughput = 910.22 GB/s

图 5. H100 内存图显示 GMEM 吞吐量 = 910.22 GB/s(不使用 TMA 的 Triton GEMM),M=128、N=4096、K=4096

通过利用我们上面提到的 Triton API 更改的 TMA,我们可以研究 Triton 为使用 TMA 的单个 2D tile 加载生成的 PTX。

加载 Tile 的 PTX (cp.async.bulk.tensor) - H100 使用 TMA

bar.sync 	0;
shr.u32 	%r5, %r4, 5;
shfl.sync.idx.b32	%r66, %r5, 0, 31, -1;

elect.sync _|%p7, 0xffffffff;


add.s32 	%r24, %r65, %r67;
shl.b32 	%r25, %r66, 7;

@%p8
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r24], [%rd26, {%r25,%r152}], [%r19];

cp.async.bulk.tensor.2d.shared TMA 指令分别传递共享内存中的目标地址、指向张量映射的指针、张量映射坐标和指向 mbarrier 对象的指针。

H100 Memory Chart GMEM Throughput =1.45 TB/s

图 6. H100 内存图 GMEM 吞吐量 = 1.45 TB/s(使用 TMA 的 Triton GEMM),M=128、N=4096、K=4096

为了获得最佳性能,我们对 TMA GEMM 内核进行了广泛的调优。在其他参数(如 tile 大小、warp 数量和流水线阶段数量)中,当我们将 TMA_SIZE(描述符大小)从 128 增加到 512 时,观察到内存吞吐量增加最多。从上面的 NCU 配置文件中,我们可以看到,最终调优的内核已将全局内存传输吞吐量从 910 GB/s 提高到 1.45 TB/s,与非 TMA Triton GEMM 内核相比,GMEM 吞吐量提高了 59%

CUTLASS 和 Triton FP8 GEMM 及 TMA 实现的比较 - 内核架构

Triton vs CUTLASS Ping-Pong FP8 GEMM TFLOPs, M=M, N=4096, K=4096

图 7. Triton 与 CUTLASS Ping-Pong FP8 GEMM TFLOPs,M=M、N=4096、K=4096

上图显示了 CUTLASS Ping-Pong GEMM 内核 相对于 Triton 的性能。Ping-Pong 内核对 TMA 的利用方式与 Triton 不同。它充分利用了其硬件和软件功能,而 Triton 目前没有这样做。具体而言,CUTLASS 支持以下 TMA 功能,这些功能有助于解释纯 GEMM 性能方面的差距:

  1. TMA 多播

    • 支持将数据从 GMEM 复制到多个 SM
  2. Warp 专用化

    • 支持线程块内的 warp 组承担不同的角色
  3. 张量映射(TMA 描述符)预取

    • 支持从 GMEM 预取张量映射对象,从而允许 TMA 加载的流水线化

为了将性能数字放在上下文中,下面我们显示了一个“加速”图表,重点介绍了延迟差异(百分比表示)

% Speedup of CUTLASS Ping-Pong vs Triton FP8 with TMA.

图 8: CUTLASS Ping-Pong 与使用 TMA 的 Triton FP8 的加速百分比。

此加速纯粹是内核吞吐量,不包括我们将在下面讨论的 E2E 启动开销。

TMA 描述符移动 - Triton 和 CUTLASS 之间的关键差异,具有 E2E 性能影响

如前所述,2D+ 维度 TMA 描述符的创建发生在主机上,然后传输到设备。但是,此传输过程的发生方式因实现而异。

在这里,我们展示了 Triton 传输 TMA 描述符的方式与 CUTLASS 的差异。

回想一下,TMA 传输需要一个特殊的数据结构,即张量映射,该映射需要在 CPU 上通过 cuTensorMap API 创建,对于 FP8 GEMM 内核,这意味着创建三个描述符,分别用于 A、B 和 C。我们在下面看到,对于 Triton 和 CUTLASS 内核,都调用了相同的 CPU 过程。

Calls to cuTensorMapEncodeTiled (Both Triton and CUTLASS use this path)

图 7. 对 cuTensorMapEncodeTiled 的调用(Triton 和 CUTLASS 都使用此路径)

但是,对于 Triton,每个描述符都在其自己独立的复制内核中传输,这增加了大量的开销,并成为在端到端使用推理场景中使用此内核的障碍。

Three H2D Copy Kernels are launched before the kernel execution, for A, B and C

图 8. 在内核执行之前,为 A、B 和 C 启动了三个 H2D 复制内核

由于 TMA 描述符传递到内核的方式,在 CUTLASS 实现中没有观察到这些复制。我们可以从下面的 PTX 中看到,对于 Cutlass,张量映射是通过值传递给内核的。

.entry _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE(

.param .align 64 .b8 _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0[1024]


mov.b64 	%rd110, _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_10bfloat16_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEES8_NS7_ILi256EEEEEENS6_IJNS7_ILi1EEESB_SB_EEENS_4gemm24KernelTmaWarpSpecializedENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0;

add.s64 	%rd70, %rd110, 704;
cvta.param.u64 	%rd69, %rd70;

cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd69, {%r284, %r283}], [%r1880];

图 9. CUTLASS 内核 PTX 显示通过值传递

通过直接传递 TMA 描述符而不是传递全局内存指针,CUTLASS 内核避免了三个额外的 H2D 复制内核,而是将这些复制包含在 GEMM 的单个设备内核启动中。

由于描述符移动到设备的方式存在差异,因此内核延迟(包括准备张量以供 TMA 消耗的时间)差异很大。对于 M=1-128、N=4096、K=4096,CUTLASS pingpong 内核的平均延迟为 10us,而 Triton TMA 内核的平均完成时间为 4ms。这慢了约 3330 倍,并且似乎与 Triton 为 TMA 描述符传输启动的 3 个独立内核直接相关。

Cuda 图可能是减少这种情况的一种方法,但考虑到 H2D 复制造成的开销,当前 Triton 实现(端到端测量时)不具有竞争力。对 Triton 编译器管理 TMA 描述符的方式进行修改可能会解决此差距。因此,我们在上面的数据中专注于比较实际的计算内核吞吐量,而不是 E2E。

结果摘要

Triton FP8 TMA GEMM TFLOPs Comparison

图 10. Triton FP8 TMA GEMM TFLOPs 比较

M Triton TMA Triton 教程 Triton SplitK cuBLAS FP8 cuBLAS FP16 CUTLASS Ping-Pong FP8
1 2.5 1 2.4 1.5 1.8 3.57
2 5.1 2.5 4.8 3.1 3.6 5.9
4 10.3 7.21 9.6 6.1 7.2 14.3
8 21.0 16.5 19.2 12.3 14.4 28.6
16 44.5 41.0 37.2 24.5 27.7 55.1
32 89.7 81.2 72.2 71.6 56.8 114.4
64 178.5 163.7 130.8 144.6 105.3 228.7
128 359.7 225.9 160.1 244.0 189.2 377.7

图 11. Triton FP8 TMA GEMM TFLOPs 比较表

上面的图表和表格总结了我们在单个 NVIDIA H100 上通过利用 TMA 硬件单元,相对于非 TMA Triton 内核和高性能 CUDA (cuBLAS) 内核,在 FP8 GEMM 上实现的增益。需要注意的关键点是此内核相对于竞争对手的卓越的(随批次大小)缩放属性。我们基准测试的问题大小代表了小到中等批次大小 LLM 推理中发现的矩阵形状。因此,对于那些有兴趣利用此内核进行 FP8 LLM 部署用例的人来说,中等 M 范围(M=32 到 M=128)的 TMA GEMM 内核性能将至关重要,因为 FP8 压缩数据类型可以允许更大的矩阵容纳在 GPU 内存中。

为了总结我们的分析,Triton 和 CUTLASS 中的 TMA 实现的不同之处在于完整功能集支持(多播、预取等)以及 TMA 描述符传递到 GPU 内核的方式。如果以更接近 CUTLASS 内核的方式(按值传递)传递此描述符,则可以避免不必要的 H2D 复制,从而大大提高 E2E 性能。

未来工作

对于未来的研究,我们计划通过与社区合作将 CUTLASS TMA 加载架构整合到 Triton 中,以及研究 FP8 GEMM 的协同内核(Ping-Pong 内核的修改策略),来改进这些结果。

此外,一旦在 Triton 中启用线程块集群和 TMA 原子操作等功能,我们或许可以通过利用 TMA GEMM 内核中的 SplitK 策略来获得进一步的加速,因为 Hopper 上的原子操作可以在分布式共享内存 (DSMEM) 中执行,而不是在 L2 缓存中执行。我们还注意到 NVIDIA Hopper GPU 与其他 AI 硬件加速器(如 Google 的 TPU 和 IBM 的 AIU)的相似之处,这些加速器都是数据流架构。在 Hopper 上,由于增加了 TMA(我们在本博客中广泛讨论过)和 DSMEM(我们计划在以后的文章中介绍),数据现在可以从 GMEM“流”到连接的 SM 网络。