Meta:Hongtao Yu, Manman Ren, Bert Maher, Shane Nay
NVIDIA:Gustav Zhu, Shuhao Jiang
在过去的几个月里,我们一直致力于通过 Triton 编译器为 PyTorch 和 Triton 用户启用高级 GPU 功能。我们的核心目标之一是在 NVIDIA Hopper GPU 上引入 Warp 特化 (Warp Specialization) 支持。今天,我们很高兴地宣布,我们的工作已经实现了完全自动化的 Triton Warp 特化,并将随即将发布的 Triton 3.2 版本(随 PyTorch 2.6 发布)提供给用户。PyTorch 用户可以通过实现自定义的 Triton 内核来利用此功能。这项工作借鉴了 NVIDIA 在 Triton 中对 Warp 特化的初步实现,我们期待在未来与社区共同进一步完善该功能。
Warp 特化 (WS) 是一种 GPU 编程技术,它将线程块 (threadblock) 内的 Warp(NVIDIA GPU 上一组 32 个线程)分配给不同的角色或任务。这种方法通过实现任务差异化或协作处理,高效地执行工作负载,从而优化了性能。它通过利用异步执行模型来增强内核性能,其中内核的不同部分由独立的硬件单元管理。在 NVIDIA H100 上,通过共享内存进行的单元间数据通信非常高效。与统一 Warp 方法相比,Warp 特化使硬件多任务 Warp 调度器能够更有效地运行,从而最大限度地提高资源利用率和整体性能。
以 GEMM 为例,在 H100 GPU 上典型的统一 Warp 方法涉及每个线程块 8 个 Warp,共同计算输出张量的一个块 (tile)。这 8 个 Warp 被划分为两个 Warp 组 (WG),每个组使用高效的 Warp 组级 MMA (WGMMA) 指令协同计算半个块,如图 1 所示。

图 1. 具有统一 Warp 的 GEMM K 循环体
得益于优雅的软件流水线技术,这种实现方式清晰、易于理解,且性能普遍良好。流水线的作用是通过在不同的硬件单元上执行非相关操作来增强指令级并行性。例如,来自下一次循环迭代的加载操作可以与当前迭代中的 WGMMA 操作同时执行。然而,这种方法高度依赖编译器来精确构建指令序列,以确保加载和 WGMMA 指令在准确的时间发出。虽然这对于涉及操作次数有限的 GEMM 来说相对简单,但对于 Flash Attention 等更复杂的内核,挑战则显著增加。
另一方面,Warp 特化通过将旨在不同硬件单元上同时运行的操作分离到不同的 Warp 中,并使用共享内存中的低成本屏障高效同步,从而解决了编程挑战。这使得每个 Warp 都能拥有自己的指令序列,得益于多路 Warp 调度器,指令可以持续发出和执行,而不受其他操作的中断。图 2 展示了 Warp 特化后的 GEMM 示意图。

图 2. 具有特化 Warp 的 GEMM K 循环体
如何启用 WS
要启用 Warp 特化,用户只需指定两个自动调优标志:num_consumer_groups 和 num_buffers_warp_spec。例如,一个 Warp 特化后的 GEMM 实现如下所示。用户可以通过设置非零的 num_consumer_groups(定义消费者 Warp 组的数量)来启用 Warp 特化。目前没有对应的标志来设置生产者 Warp 组的数量,因为目前仅支持一个生产者。num_buffers_warp_spec 标志指定了生产者 Warp 组与消费者 Warp 组进行通信所使用的缓冲区数量。在持久化 GEMM 教程中提供了一个可用的 Warp 特化内核示例。
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
num_consumer_groups=2,
num_buffers_warp_spec=3,
),
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_persistent_ws_kernel(
a_ptr, b_ptr, c_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_m = pid // num_pid_m
pid_n = pid % num_pid_n
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = acc.to(tl.float16)
c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
tl.store(c_ptrs, c)
底层原理
Warp 特化使用一系列分层编译器转换和 IR 变更,将用户非特化的内核转换为特化后的机器代码。这些包括:
- 任务分区 (Task Partitioning):整个内核根据预定义的启发式规则自动划分为异步任务。编译器决定如何利用一个生产者 Warp 组和用户指定数量的消费者 Warp 组来执行内核。它为特定的锚点操作分配任务 ID,这些 ID 随后通过异步任务 ID 传播和依赖分析影响剩余操作的任务分配。由于共享内存是所有支持平台上 Warp 组之间进行数据传输最有效的方法,编译器会优化任务分区,以尽量减少向共享内存的寄存器溢出,从而确保高效执行。
- 多消费者组的数据分区 (Data Partitioning for Multiple Consumer Groups):在多个消费者组之间高效划分数据是优化工作负载分配的关键。在 H100 GPU 上,编译器默认尝试沿
M维度对输入张量A进行分区,允许每个消费者组独立计算输出张量的一半。这种称为协作分区的策略在大多数情况下都能实现效率最大化。然而,如果这种拆分导致效率低下(例如产生的工作负载小于原生 WGMMA 指令大小),编译器将动态调整并改为沿N维度进行分区。 - 数据流流水线 (Dataflow Pipelining):编译器创建循环共享内存缓冲区,以在多维循环中流水线化数据流。Warp 特化流水线支持复杂的控制流。例如,我们的 Warp 特化持久化 GEMM 内核使用了双重嵌套循环,允许生产者在消费者完成前一个块的计算时,开始为下一个输出块获取数据。
- 通信操作 (Communication Operations):我们引入了四个高级 Triton GPU IR (TTGIR) 通信操作——
ProducerAcquireOp、ProducerCommitOp、ConsumerWaitOp和ConsumerReleaseOp——用于管理流水线化的数据流。这些操作同时支持 TMA 和非 TMA 内存操作。 - 代码分区 (Code Partitioning):每个异步任务都被划分到其独立的区域,并由 Warp 组 ID 检查进行保护。控制依赖项根据需要进行复制。
- TTGIR 到 LLVM/PTX 的物化 (TTGIR to LLVM/PTX Materialization):TTGIR 通信操作被物化为相应的 LLVM/PTX 屏障操作。
性能
此次Warp 特化发布引入了一系列 Triton 编译器转换,将用户代码转换为 Warp 特化内核。该功能已应用于多个关键内核,包括 Flash Attention 和 FP8 行式 GEMM,实现了 10% 到 15% 的性能显著提升。以下我们列出了这些高影响力内核的最新性能指标。


未来工作
展望未来,我们计划通过引入 Ping-Pong 调度、扩展缓冲区共享支持、改进对 TMA 的透明处理以及针对即将推出的 NVIDIA 硬件优化分区启发式规则等新功能,进一步增强 Triton 的 Warp 特化支持。