跳转到主要内容
博客

深入探讨 Hopper TMA 单元在 FP8 GEMM 中的应用

作者: 2024 年 7 月 22 日2024 年 11 月 12 日无评论

摘要

Hopper (H100) GPU 架构被称为“第一个真正异步的 GPU”,它包含了一个全新的、完全异步的硬件复制引擎,用于全局内存和共享内存之间的批量数据传输,称为 Tensor Memory Accelerator (TMA)。虽然 CUTLASS 通过其异步管道范式内置支持 TMA,但 Triton 通过一个实验性 API 暴露了 TMA 支持。

在这篇文章中,我们将深入探讨 TMA 的工作原理,以帮助开发者理解这个新的异步复制引擎。我们还将通过在 Triton 中构建一个支持 TMA 的 FP8 GEMM 内核,展示利用 TMA 对 H100 内核的重要性,该内核对于中小型问题规模比 cuBLAS FP16 提供了 1.4-2.2 倍的性能提升。最后,我们将展示 Triton 和 CUTLASS 之间主要的实现差异,这些差异可能解释了 Triton 中 TMA 性能下降的报告。我们开源了我们的实现,以便在https://github.com/pytorch-labs/applied-ai/tree/main/kernels 上进行复现和审查

The throughput in TFLOPs of various Triton and cuBLAS FP8 and FP16 kernels, for M=M, N=4096, K=4096. The red line is the Triton TMA, which showcases the advantages of leveraging TMA.

图 1. 各 Triton 和 cuBLAS FP8 和 FP16 内核的吞吐量(TFLOPs),M=M,N=4096,K=4096。红线表示 Triton TMA,展示了利用 TMA 的优势。

TMA 背景

TMA 是 H100 的一个硬件新增功能,它允许应用程序在 GPU 全局内存和共享内存之间异步双向传输 1D-5D 张量。此外,TMA 不仅可以将相同的数据传输到调用 SM 的共享内存,还可以传输到同一线程块集群中的其他 SM 的共享内存。这被称为“多播”。

TMA 非常轻量级,只需一个线程即可启动 TMA 传输。通过将数据直接从 GMEM(全局内存)移动到 SMEM(共享内存),这避免了早期 GPU 在不同内存空间之间移动数据时使用寄存器的要求。

A100-style data movement vs H100 with TMA.  TMA hardware eliminates the need for a large amount of threads and registers participating in bulk data transfers.

图 2. A100 风格数据移动与 H100 与 TMA。TMA 硬件消除了大量线程和寄存器参与批量数据传输的需要。(图片来源:Nvidia)

单个线程可以发出大型数据移动指令,允许给定线程块的大部分在数据传输进行中继续处理其他指令。结合异步流水线,这使得内存传输可以轻松隐藏,并确保给定线程块集群的大部分可以专注于计算任务。

这种轻量级的数据移动调用方式使得创建 warp-group 专用内核成为可能,其中 warp-group 承担不同的角色,即生产者和消费者。生产者选择一个领导线程来触发 TMA 请求,然后通过到达屏障与消费者(MMA)warp-group 进行异步协调。消费者随后使用 warp-group MMA 处理数据,并在从 SMEM 缓冲区读取完成后向生产者发出信号,然后循环重复。

此外,在线程块集群内部,生产者可以降低其最大寄存器要求,因为它们只发出 TMA 调用,从而有效地将额外的寄存器转移给 MMA 消费者,这有助于缓解消费者的寄存器压力。

此外,TMA 处理了请求数据应放置的共享内存目标的地址计算。这就是为什么调用线程(生产者)可以如此轻量级。

为了确保最大的读取访问速度,TMA 可以根据交织指令布局到达的数据,以确保消费者能够以最快的速度读取到达的数据,因为交织模式有助于避免共享内存bank冲突。

最后,对于传出或将数据从 SMEM 移动到 GMEM 的 TMA 指令,TMA 还可以包括归约操作(add/min/max)和按位操作(and/or)。

TMA 在 Triton 中的使用

Hopper 之前的加载

offs_m = pid_m*block_m + tl.arange(0, block_m)
offs_n = pid_n*block_n + tl.arange(0, block_n)
offs_k = tl.arange(0, block_k)

a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k[None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k[:, None]*stride_bk + offs_bn[None, :]*stride_bn)

a = tl.load(a_ptrs)
b = tl.load(b_ptrs)

图 3. Triton 中从全局内存到共享内存的传统批量加载方式

在上面展示 Hopper 之前加载的 Triton 示例中,我们看到张量 a 和 b 的数据是如何由每个线程块通过计算其相关 program_id (pid_m, pid_n, k) 的全局偏移量 (a_ptrs, b_ptrs),然后发出请求将内存块移动到 a 和 b 的共享内存中来加载的。

现在让我们来看看如何在 Triton 中使用 TMA 执行加载。

TMA 指令需要一个称为张量图的特殊数据结构,与上面我们直接传递指向全局内存的指针不同。要构建张量图,我们首先在 CPU 上创建一个 TMA 描述符。描述符通过使用 cuTensorMapEncode API 处理张量图的创建。张量图包含元数据,例如张量的全局和共享内存布局,并作为存储在全局内存中的多维张量结构的压缩表示。

TMA address generation via a copy descriptor

图 4. 通过复制描述符生成 TMA 地址(图片来源:Nvidia)

TMA 描述符包含张量的关键属性:

  1. 基指针
  2. 形状和块大小
  3. 数据类型

TMA 描述符在内核之前在主机上创建,然后通过将描述符传递给 torch 张量来移动到设备。因此,在 Triton 中,GEMM 内核接收指向张量图的全局指针。

Triton 主机代码

   desc_a = np.empty(TMA_SIZE, dtype=np.int8)
   desc_b = np.empty(TMA_SIZE, dtype=np.int8)
   desc_c = np.empty(TMA_SIZE, dtype=np.int8)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)

   triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
  
   desc_a = torch.tensor(desc_a, device='cuda')
   desc_b = torch.tensor(desc_b, device='cuda')
   desc_c = torch.tensor(desc_c, device='cuda')

这是用于在内核调用函数中设置描述符的代码。

Triton 设备代码

偏移量/指针算术

   offs_am = pid_m * block_m
   offs_bn = pid_n * block_n
   offs_k = 0

加载

  a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
  b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)

存储

 tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

我们不再需要在内核中为加载和存储函数计算指针数组。相反,我们传递一个描述符指针、偏移量、块大小和输入数据类型。这简化了地址计算并降低了寄存器压力,因为我们不再需要在软件中进行复杂的指针运算并为地址计算分配 CUDA 核心。

TMA 性能分析

下面,我们讨论 Hopper 上不同加载机制的 PTX 指令。

用于加载 Tile 的 PTX (cp.async) – H100 无 TMA

add.s32 	%r27, %r100, %r8;
add.s32 	%r29, %r100, %r9;
selp.b32 	%r30, %r102, 0, %p18;


@%p1 cp.async.cg.shared.global [ %r27 + 0 ], [ %rd20 + 0 ], 0x10, %r30;
@%p1 cp.async.cg.shared.global [ %r29 + 0 ], [ %rd21 + 0 ], 0x10, %r30;


cp.async.commit_group ;

在这里,我们观察到负责全局内存复制的旧的 cp.async 指令。从下面的跟踪中我们可以看到,两次加载都绕过了 L1 缓存。新的 TMA 加载的一个主要区别是,在 A 和 B 的瓦片准备好被 Tensor Core 消耗之前,我们需要执行一个 ldmatrix 指令,该指令操作寄存器文件中的数据。在 Hopper 上,数据现在可以直接从共享内存中重用。

H100 Memory Chart showing GMEM Throughput = 910.22 GB/s

图 5. H100 内存图显示 GMEM 吞吐量 = 910.22 GB/s(Triton GEMM 不带 TMA),M=128,N=4096,K=4096

通过利用我们上面提到的 Triton API 更改中的 TMA,我们可以研究 Triton 为单个 2D 瓦片加载与 TMA 生成的 PTX。

用于加载 Tile 的 PTX (cp.async.bulk.tensor) – H100 使用 TMA

bar.sync 	0;
shr.u32 	%r5, %r4, 5;
shfl.sync.idx.b32	%r66, %r5, 0, 31, -1;

elect.sync _|%p7, 0xffffffff;


add.s32 	%r24, %r65, %r67;
shl.b32 	%r25, %r66, 7;

@%p8
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r24], [%rd26, {%r25,%r152}], [%r19];

cp.async.bulk.tensor.2d.shared TMA 指令依次传递共享内存中的目标地址、指向张量映射的指针、张量映射坐标和指向 mbarrier 对象的指针。

H100 Memory Chart GMEM Throughput =1.45 TB/s

图 6. H100 内存图 GMEM 吞吐量 = 1.45 TB/s(Triton GEMM TMA),M=128,N=4096,K=4096

为了获得最佳性能,我们对 TMA GEMM 内核进行了广泛的调优。在其他参数如瓦片大小、warp 数量和流水线阶段数量中,当我们将 TMA_SIZE(描述符大小)从 128 增加到 512 时,观察到内存吞吐量增加最大。从上面的 NCU 配置文件中,我们可以看到,最终调优后的内核将全局内存传输吞吐量从 910 GB/s 提高到 1.45 TB/s,比非 TMA Triton GEMM 内核增加了 59% 的 GMEM 吞吐量。

CUTLASS 和 Triton FP8 GEMM 及 TMA 实现的比较 – 内核架构

Triton vs CUTLASS Ping-Pong FP8 GEMM TFLOPs, M=M, N=4096, K=4096

图 7. Triton 与 CUTLASS Ping-Pong FP8 GEMM TFLOPs 对比,M=M,N=4096,K=4096

上图显示了 CUTLASS Ping-Pong GEMM 内核与 Triton 的性能。Ping-Pong 内核与 Triton 不同地利用 TMA。它利用了其所有的硬件和软件功能,而 Triton 目前则没有。具体来说,CUTLASS 支持以下 TMA 功能,这些功能有助于解释纯 GEMM 性能中的差距:。

  1. TMA 多播
    • 支持将数据从 GMEM 复制到多个 SM
  2. Warp 专用化
    • 使线程块内的 warp 组承担不同的角色
  3. 张量图(TMA 描述符)预取
    • 支持从 GMEM 预取张量图对象,从而允许 TMA 加载的流水线操作

为了更好地理解这些性能数据,下面我们展示了一个“加速”图,以百分比形式突出显示延迟差异

% Speedup of CUTLASS Ping-Pong vs Triton FP8 with TMA.

图 8: CUTLASS Ping-Pong 相较于带 TMA 的 Triton FP8 的加速百分比。

这种加速纯粹是内核吞吐量,不包括我们将讨论的 E2E 启动开销。

TMA 描述符移动——Triton 和 CUTLASS 之间的一个关键差异,对 E2E 性能有影响

如前所述,2D+ 维 TMA 描述符的创建在主机上进行,然后传输到设备。然而,这种传输过程根据实现的不同而有很大差异。

在这里,我们展示了 Triton 和 CUTLASS 在 TMA 描述符传输方式上的差异。

回想一下,TMA 传输需要通过 cuTensorMap API 在 CPU 上创建一个特殊的张量图数据结构,对于 FP8 GEMM 内核来说,这意味着要为 A、B 和 C 各创建一个描述符。我们将在下面看到,对于 Triton 和 CUTLASS 内核,都调用了相同的 CPU 过程。

Calls to cuTensorMapEncodeTiled (Both Triton and CUTLASS use this path)

图 7. 对 cuTensorMapEncodeTiled 的调用(Triton 和 CUTLASS 都使用此路径)

然而,对于 Triton,每个描述符都在其独立的复制内核中传输,这增加了大量的开销,并成为在端到端推理场景中使用此内核的障碍。

Three H2D Copy Kernels are launched before the kernel execution, for A, B and C

图 8. 在内核执行之前,启动了三个 H2D 复制内核,分别用于 A、B 和 C

在 CUTLASS 实现中没有观察到这些副本,这是由于 TMA 描述符传递给内核的方式。从下面的 PTX 中我们可以看到,使用 Cutlass,张量映射以传值方式传递给内核。

.entry _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE(

.param .align 64 .b8 _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_6half_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEENS7_ILi128EEES9_EEENS6_IJNS7_ILi2EEENS7_ILi1EEESC_EEENS_4gemm32KernelTmaWarpSpecializedPingpongENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0[1024]


mov.b64 	%rd110, _ZN7cutlass13device_kernelIN49_GLOBAL__N__8bf0e19b_16_scaled_mm_c3x_cu_2bec3df915cutlass_3x_gemmIaNS_10bfloat16_tENS1_14ScaledEpilogueEN4cute5tupleIJNS5_1CILi64EEES8_NS7_ILi256EEEEEENS6_IJNS7_ILi1EEESB_SB_EEENS_4gemm24KernelTmaWarpSpecializedENS_8epilogue18TmaWarpSpecializedEE10GemmKernelEEEvNT_6ParamsE_param_0;

add.s64 	%rd70, %rd110, 704;
cvta.param.u64 	%rd69, %rd70;

cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd69, {%r284, %r283}], [%r1880];

图 9. CUTLASS 内核 PTX 显示传值方式

通过直接传递 TMA 描述符,而不是传递全局内存指针,CUTLASS 内核避免了三个额外的 H2D 复制内核,相反,这些复制被包含在 GEMM 的单个设备内核启动中。

由于描述符移动到设备的方式不同,包括准备张量以供 TMA 消耗的时间在内的内核延迟也大相径庭。对于 M=1-128, N=4096, K=4096,CUTLASS pingpong 内核的平均延迟为 10us,而 Triton TMA 内核平均完成时间为 4ms。这慢了约 3330 倍,并且似乎与 Triton 独立启动的 3 个 TMA 描述符传输内核直接相关。

Cuda graphs 可能是减少这种情况的一种方法,但考虑到 H2D 复制造成的开销,当前的 Triton 实现在端到端测量时并不具有竞争力。对 Triton 编译器管理 TMA 描述符的方式进行重构可能会解决这个差距。因此,我们上面的数据主要关注比较实际的计算内核吞吐量,而不是端到端性能。

结果摘要

Triton FP8 TMA GEMM TFLOPs Comparison

图 10. Triton FP8 TMA GEMM TFLOPs 比较

MTriton TMATriton 教程Triton SplitKcuBLAS FP8cuBLAS FP16CUTLASS Ping-Pong FP8
12.512.41.51.83.57
25.12.54.83.13.65.9
410.37.219.66.17.214.3
821.016.519.212.314.428.6
1644.541.037.224.527.755.1
3289.781.272.271.656.8114.4
64178.5163.7130.8144.6105.3228.7
128359.7225.9160.1244.0189.2377.7

图 11. Triton FP8 TMA GEMM TFLOPs 比较表

上图和表格总结了我们通过利用 TMA 硬件单元在单个 NVIDIA H100 上为 FP8 GEMM 实现的增益,超过了非 TMA Triton 内核和高性能 CUDA (cuBLAS) 内核。需要注意的是,该内核的卓越扩展性(随批次大小)优于竞争对手。我们测试的问题规模代表了中小型批次大小 LLM 推理中发现的矩阵形状。因此,对于那些有兴趣将此内核用于 FP8 LLM 部署用例的人来说,中 M 范围(M=32 到 M=128)的 TMA GEMM 内核性能将至关重要,因为 FP8 压缩数据类型可以允许更大的矩阵适应 GPU 内存。

总结我们的分析,Triton 和 CUTLASS 中的 TMA 实现在完整功能集支持(多播、预取等)以及 TMA 描述符如何传递给 GPU 内核方面存在差异。如果此描述符以更接近 CUTLASS 内核的方式(传值)传递,则可以避免多余的 H2D 复制,从而大大提高 E2E 性能。

未来工作

在未来的研究中,我们计划通过与社区合作,将 CUTLASS 的 TMA 加载架构整合到 Triton 中,并研究 FP8 GEMM 的协作内核(Ping-Pong 内核的修改策略),以改进这些结果。

此外,一旦 Triton 中启用了线程块集群和 TMA 原子操作等功能,我们可能会通过在 TMA GEMM 内核中利用 SplitK 策略获得进一步的加速,因为 Hopper 上的原子操作可以在分布式共享内存 (DSMEM) 而不是 L2 缓存中执行。我们还注意到 NVIDIA Hopper GPU 与其他 AI 硬件加速器(如 Google 的TPU 和 IBM 的AIU)的相似之处,它们都是数据流架构。在 Hopper 上,由于 TMA 的加入(我们在本博客中广泛讨论),数据现在可以从 GMEM“流”向连接的 SM 网络,我们计划在未来的文章中介绍 DSMEM。