作者:Meta 和 NVIDIA

Meta:于洪涛, 任曼曼, Bert Maher, Shane Nay
NVIDIA:Gustav Zhu, 蒋书豪

在过去的几个月里,我们一直致力于通过 Triton 编译器为 PyTorch 和 Triton 用户启用高级 GPU 功能。我们的一个主要目标是在 NVIDIA Hopper GPU 上引入 warp specialization 支持。今天,我们很高兴地宣布,我们的努力已促成全自动 Triton warp specialization 的推出,用户将在 Triton 3.2 即将发布的版本中使用此功能,该版本将随 PyTorch 2.6 一起发布。PyTorch 用户可以通过实现用户定义的 Triton 内核来利用此功能。这项工作利用了 NVIDIA 在 Triton 中对 warp specialization 的初步实现,我们期待未来与社区进一步合作开发。

Warp specialization (WS) 是一种 GPU 编程技术,其中 threadblock 内的 warps(NVIDIA GPU 上的 32 个线程组)被分配不同的角色或任务。这种方法通过实现需要任务区分或协作处理的工作负载的高效执行来优化性能。它通过利用异步执行模型来增强内核性能,在该模型中,内核的不同部分由独立的硬件单元管理。这些单元之间的数据通信,通过 NVIDIA H100 上的共享内存实现,效率极高。与统一 warp 方法相比,warp specialization 允许硬件多任务 warp 调度器更有效地工作,从而最大限度地提高资源利用率和整体性能。

以 GEMM 为例,H100 GPU 上典型的统一 warp 方法涉及每个 thread block 中的 8 个 warps 共同计算输出张量的一个 tile。这 8 个 warps 被分成两个 warp 组 (WG),每个组使用高效的 warp 组级 MMA (WGMMA) 指令协作计算 tile 的一半,如图 1 所示。

Figure 1. GEMM K-loop Body with Uniform Warps

图 1. 采用统一 Warps 的 GEMM K-loop 主体

得益于精妙的软件流水线,该实现简洁、易于理解,并且通常性能良好。流水线器的目的是通过在不同的硬件单元上执行非依赖操作来增强指令级并行性。例如,来自下一个循环迭代的加载操作可以与当前迭代中的 WGMMA 操作同时执行。然而,这种方法严重依赖编译器来精心设计指令序列,以确保加载和 WGMMA 指令在恰好正确的时间发出。虽然这对涉及有限数量操作的 GEMM 相对简单,但对于更复杂的内核(例如 flash attention)来说,则变得更加具有挑战性。

另一方面,warp specialization 通过将旨在在不同硬件单元上同时运行的操作分离到不同的 warps 中来解决编程挑战,并使用共享内存中的低成本屏障有效地同步它们。这使得每个 warp 都有自己的指令序列,借助多路 warp 调度器,指令可以连续发出和执行,而不会被其他操作中断。warp-specialized GEMM 的图解可以在图 2 中看到。

Figure 2. GEMM K-loop Body with Specialized Warps

图 2. 采用 Specialized Warps 的 GEMM K-loop 主体

如何启用 WS

要启用 warp specialization,用户只需指定两个 autotune 标志:num_consumer_groups 和 num_buffers_warp_spec。例如,一个 warp-specialized GEMM 实现可能如下所示。用户可以通过为 num_consumer_groups 设置非零值来启用 warp specialization,该值定义了消费者 warp 组的数量。目前没有对应的标志来设置生产者 warp 组的数量,因为目前只支持一个生产者。num_buffers_warp_spec 标志指定了生产者 warp 组将用于与消费者 warp 组通信的 buffer 数量。在 persistent GEMM 教程中提供了一个 warp-specialized 内核的工作示例。

@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_persistent_ws_kernel(
   a_ptr, b_ptr, c_ptr, M, N, K,
   stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
   pid = tl.program_id(axis=0)
   num_pid_m = tl.cdiv(M, BLOCK_M)
   num_pid_n = tl.cdiv(N, BLOCK_N)
   pid_m = pid // num_pid_m
   pid_n = pid % num_pid_n
   offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
   offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
   offs_k = tl.arange(0, BLOCK_K)
   a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
   b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
   acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
   for k in range(0, tl.cdiv(K, BLOCK_K)):
       a = tl.load(a_ptrs)
       b = tl.load(b_ptrs)
       acc += tl.dot(a, b)
       a_ptrs += BLOCK_K * stride_ak
       b_ptrs += BLOCK_K * stride_bk
   c = acc.to(tl.float16)
   c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
   tl.store(c_ptrs, c)

幕后原理

Warp specialization 使用一组分层编译器变换和 IR 变更,将用户的非 warp-specialized 内核转换为 warp-specialized 机器码。其中包括

  • 任务划分:整个内核根据预定义的启发式方法自动划分为异步任务。编译器决定如何利用一个生产者 warp 组和用户指定数量的消费者 warp 组来执行内核。它为特定的锚点操作分配任务 ID,然后通过异步任务 ID 传播和依赖分析影响剩余操作的任务分配。由于共享内存是所有支持平台上 warp 组之间数据传输最有效的方法,编译器会优化任务划分,以最大限度地减少寄存器溢出到共享内存,从而确保高效执行。
  • 多消费者组数据划分:在多个消费者组之间高效划分数据是优化工作负载分布的关键。在 H100 GPU 上,编译器默认尝试沿 M 维度划分输入张量 A,允许每个消费者组独立计算输出张量的一半。这种策略,被称为协作划分,在大多数条件下都能最大化效率。但是,如果这种分割导致效率低下——例如生成的工作负载小于原生的 WGMMA 指令大小——编译器会动态调整并改为沿 N 维度进行划分。
  • 数据流流水线:编译器创建循环共享内存 buffer,以在多维循环中流水化数据流。Warp-specialized 流水线支持复杂的控制流。例如,我们的 warp-specialized persistent GEMM 内核使用双层嵌套循环,允许生产者在消费者完成先前 tile 的计算时开始获取下一个输出 tile 的数据。
  • 通信操作我们引入了四种高级 Triton GPU IR (TTGIR) 通信操作—ProducerAcquireOp, ProducerCommitOp, ConsumerWaitOp, ConsumerReleaseOp—来管理流水线数据流。这些操作支持 TMA 和非 TMA 内存操作。
  • 代码划分:每个异步任务都被概括为一个独立的 代码区域,并由 warp 组 ID 检查保护。根据需要复制控制依赖关系。
  • TTGIR 到 LLVM/PTX 具体化:TTGIR 通信操作被具体化为相应的 LLVM/PTX 屏障操作。

性能

此次 warp specialization 发布引入了一系列 Triton 编译器变换,它们共同将用户代码转换为 warp-specialized 内核。此功能已应用于几个关键内核,包括 Flash Attention 和 FP8 行式 GEMM,从而带来了 10% 到 15% 的显著性能提升。下面,我们重点介绍这些高影响内核的最新性能指标。

bar chart

bar chart

未来工作

展望未来,我们计划通过引入新功能(如 Ping-Pong 调度、扩展的 buffer 共享支持、改进的 TMA 透明处理、针对即将推出的 NVIDIA 硬件的优化划分启发式方法)来进一步增强 Triton 的 warp specialization 支持。