跳转到主要内容
博客

在 PyTorch 中启用高级 GPU 功能 – Warp 专门化

作者: 2025 年 2 月 5 日2025 年 5 月 3 日暂无评论

Meta:Hongtao Yu、Manman Ren、Bert Maher、Shane Nay
NVIDIA:Gustav Zhu、Shuhao Jiang

在过去的几个月里,我们一直致力于通过 Triton 编译器为 PyTorch 和 Triton 用户启用高级 GPU 功能。我们的主要目标之一是为 NVIDIA Hopper GPU 引入 warp 专用化支持。今天,我们很高兴地宣布,我们的努力已成功推出全自动 Triton warp 专用化,用户可以在即将发布的 Triton 3.2 版本中体验此功能,该版本将随 PyTorch 2.6 一同发布。PyTorch 用户可以通过实现用户定义的 Triton 内核来利用此功能。这项工作利用了 NVIDIA 在 Triton 中对 warp 专用化的初始实现,我们期待未来与社区进一步发展。

Warp 专用化 (WS) 是一种 GPU 编程技术,其中线程块内的 warp(NVIDIA GPU 上由 32 个线程组成的一组)被分配不同的角色或任务。这种方法通过实现需要任务区分或协作处理的工作负载的高效执行来优化性能。它通过利用异步执行模型来增强内核性能,其中内核的不同部分由独立的硬件单元管理。这些单元之间的数据通信通过 NVIDIA H100 上的共享内存进行,效率非常高。与统一 warp 方法相比,warp 专用化允许硬件多任务 warp 调度器更有效地运行,最大限度地提高资源利用率和整体性能。

以 GEMM 为例,H100 GPU 上典型的统一 warp 方法涉及每个线程块 8 个 warp,它们共同计算输出张量的一个瓦片。这 8 个 warp 分为两个 warp 组 (WG),每个组使用高效的 warp 组级 MMA (WGMMA) 指令协同计算瓦片的一半,如图 1 所示。

Figure 1. GEMM K-loop Body with Uniform Warps

图 1. 带有统一 Warps 的 GEMM K 循环体

由于优雅的软件流水线,该实现简洁、易于理解,并且通常表现良好。流水线器的目的是通过在不同的硬件单元上执行不相关的操作来增强指令级并行性。例如,来自下一个循环迭代的加载操作可以与当前迭代中的 WGMMA 操作同时执行。然而,这种方法严重依赖编译器来精心设计指令序列,以确保加载和 WGMMA 指令在恰好正确的时间发出。虽然这对于涉及有限数量操作的 GEMM 来说相对简单,但对于更复杂的内核(例如 Flash Attention)来说,它变得更具挑战性。

另一方面,warp 专用化通过将旨在同时在不同硬件单元上运行的操作分离到不同的 warp 中,并使用共享内存中的低成本屏障高效地同步它们来解决编程挑战。这允许每个 warp 拥有自己的指令序列,从而使指令能够连续发出和执行,而不会被其他操作中断,这得益于多路 warp 调度器。图 2 显示了 warp 专用化 GEMM 的示例。

Figure 2. GEMM K-loop Body with Specialized Warps

图 2. 带有专用 Warps 的 GEMM K 循环体

如何启用 WS

要启用 warp 专用化,用户只需指定两个自动调优标志:num_consumer_groups 和 num_buffers_warp_spec。例如,warp 专用化 GEMM 实现可能如下所示。用户可以通过为 num_consumer_groups 设置非零值来启用 warp 专用化,该值定义了消费者 warp 组的数量。目前没有相应的标志来设置生产者 warp 组的数量,因为目前只支持一个生产者。num_buffers_warp_spec 标志指定生产者 warp 组将用于与消费者 warp 组通信的缓冲区数量。持久化 GEMM 教程中提供了 warp 专用化内核的有效示例。

@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_persistent_ws_kernel(
   a_ptr, b_ptr, c_ptr, M, N, K,
   stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
   pid = tl.program_id(axis=0)
   num_pid_m = tl.cdiv(M, BLOCK_M)
   num_pid_n = tl.cdiv(N, BLOCK_N)
   pid_m = pid // num_pid_m
   pid_n = pid % num_pid_n
   offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
   offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
   offs_k = tl.arange(0, BLOCK_K)
   a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
   b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
   acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
   for k in range(0, tl.cdiv(K, BLOCK_K)):
       a = tl.load(a_ptrs)
       b = tl.load(b_ptrs)
       acc += tl.dot(a, b)
       a_ptrs += BLOCK_K * stride_ak
       b_ptrs += BLOCK_K * stride_bk
   c = acc.to(tl.float16)
   c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
   tl.store(c_ptrs, c)

幕后

Warp 专用化使用一组分层编译器转换和 IR 更改,将用户的非 warp 专用化内核转换为 warp 专用化机器代码。这些包括:

  • 任务划分:整个内核根据预定义的启发式方法自动划分为异步任务。编译器决定如何利用一个生产者 warp 组和用户指定的消费者 warp 组数量来执行内核。它将任务 ID 分配给特定的锚定操作,然后通过异步任务 ID 传播和依赖分析影响其余操作的任务分配。由于共享内存是所有支持的平台之间 warp 组间数据传输最有效的方法,因此编译器优化任务划分以最大程度地减少寄存器溢出到共享内存,确保高效执行。
  • 多消费者组的数据划分:在多个消费者组之间高效地划分数据是优化工作负载分布的关键。在 H100 GPU 上,编译器默认尝试沿 M 维度划分输入张量 A,允许每个消费者组独立计算一半的输出张量。这种策略,称为协作划分,在大多数情况下都能最大限度地提高效率。然而,如果这种划分导致效率低下(例如,生成的工作负载小于原生的 WGMMA 指令大小),编译器会动态调整并改为沿 N 维度划分。
  • 数据流流水线:编译器创建循环共享内存缓冲区,以在多维循环中流水线数据流。Warp 专用化流水线支持复杂的控制流。例如,我们的 warp 专用化持久化 GEMM 内核使用双嵌套循环,允许生产者在消费者完成前一个瓦片的计算时开始获取下一个输出瓦片的数据。
  • 通信操作:我们引入了四个高级 Triton GPU IR (TTGIR) 通信操作——ProducerAcquireOpProducerCommitOpConsumerWaitOpConsumerReleaseOp——来管理流水线数据流。这些操作支持 TMA 和非 TMA 内存操作。
  • 代码划分:每个异步任务都被分解成自己的独立代码区域,并由 warp 组 ID 检查进行保护。控制依赖关系根据需要进行复制。
  • TTGIR 到 LLVM/PTX 实例化:TTGIR 通信操作被实例化为相应的 LLVM/PTX 屏障操作。

性能

warp 专用化版本引入了一系列 Triton 编译器转换,它们共同将用户代码转换为 warp 专用化内核。此功能已应用于多个关键内核,包括 Flash Attention 和 FP8 行式 GEMM,从而带来了 10% 到 15% 的显著性能提升。下面,我们重点介绍了这些高影响力内核的最新性能指标。

bar chart
bar chart

未来工作

展望未来,我们计划通过引入新功能,例如 Ping-Pong 调度、扩展的缓冲区共享支持、改进的 TMA 透明处理以及针对即将推出的 NVIDIA 硬件的优化分区启发式方法,进一步增强 Triton 的 warp 专用化支持。