博客

使用 TLX 启用集群启动控制

什么是集群启动控制 (CLC)?

Blackwell 架构引入了集群启动控制 (CLC) 以实现动态调度。该功能允许内核根据需要启动任意数量的线程块 (threadblock),这既镜像了非持久化内核 (non-persistent kernels) 的处理方式,又保留了持久化内核带来的优势——即更少的线程块启动次数和由硬件驱动的负载均衡。

我们从一个简单的 GEMM 内核开始,它使用 32×32 的输出瓦片 (output tiles),且有 144 个 SM 可用。

图 1. 非持久化调度

启用 CLC 后,从主机启动 32×32 的网格 (grid) 会将 CTA 0–143 最初分配给 SM 0–143。

图 2. CLC 将初始 CTA 分配给 SM

例如,当 CTA 0 仍在 SM 0 上运行时,CLC 允许 SM 0 异步且原子地“窃取”下一个可用任务(例如 CTA #200),这样 SM 0 就可以立即开始处理块 #200,而无需进行新的线程块启动。

图 3. CLC 窃取工作

动态调度允许系统在执行过程中适应不断变化的工作负载和资源可用性。例如,如果运行时有 5 个额外的 SM 变得可用,它们也可以窃取并处理可用的工作。

https://docs.nvda.net.cn/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html#blackwell-cluster-launch-control 

什么是 TLX?

TLX 是 Triton DSL 的底层扩展,专为需要对 GPU 操作进行细粒度控制的专家用户设计。TLX 提供:

  • 硬件特定的内建函数(如 wgmma、async_copy 和 barrier)
  • 共享内存和本地内存管理
  • 指令级调度与控制
  • 跨线程束组 (warpgroup) 同步

这些功能通过暴露底层的 GPU 原语以及用于内存、计算和异步控制流的显式构造,实现了高级内核开发。虽然 TLX 目前专注于 NVIDIA GPU,但它允许用户实现特定于架构的优化,从而减少对编译器启发式算法的依赖。这种方法赋予了用户更多的责任和灵活性,但也可能导致不同硬件平台之间的差异性增大。

https://github.com/facebookexperimental/triton/tree/main 

TLX 中的 CLC

TLX 提供了三个 CLC API

  1. 初始化 tlx.clc_create_context(num_stages, num_consumers):为 CLC 分配共享内存。
    1. num_stages:启用流水线工作窃取。

    2. num_consumers:支持多消费者。

  2. 生产者 tlx.clc_producer(context, k, p_producer):尝试窃取一个工作阶段。
    1. context:由 clc_create_context 返回的句柄。

    2. k:阶段索引(0 到 num_stages-1)。

    3. p_producer:mbarrier 奇偶校验相位 (parity phase)。

  3. 消费者 tlx.clc_consumer(context, k, p_consumer):用于 CTA ID 解码(如果成功)。
    1. k:阶段索引。
    2. p_consumer:消费者的 mbarrier 奇偶校验相位。

初始化 API tlx.clc_create_context 同时启用了多阶段流水线和多消费者工作流。CLC 生产者-消费者设置要求在共享内存中,每个阶段包含一对 mbarrier(mbar_emptymbar_full)以及一个 CLC 响应对象。

生产者 API 将通过等待 mbar_empty 来获取任务,并通过 try_cancelmbar_full 来提交。消费者 API 将等待 mbar_full,从 CLC 响应中解码瓦片 ID,并释放 mbar_empty

# init
clc_context = tlx.clc_create_context(NUM_CLC_STAGES, 1) # only 1 CLC consumer


# init mbar parity phases
clc_phase_producer = 1
clc_phase_consumer = 0
# cicular-buffer pipeline counter
clc_buf = 0


tile_id = start_pid
while tile_id != -1:
clc_buf = clc_buf % NUM_CLC_STAGES
# producer: steal workload
tlx.clc_producer(clc_context, clc_buf, clc_phase_producer)
clc_phase_producer = clc_phase_producer ^ (clc_buf == (NUM_CLC_STAGES - 1))
... # main

# consumer: decode CTA ID
tile_id = tlx.clc_consumer(clc_context, clc_buf, clc_phase_consumer)
clc_phase_consumer = clc_phase_consumer ^ (clc_buf == (NUM_CLC_STAGES - 1))
clc_buf += 1

案例研究

比较 WS GEMMCLC+WS GEMM,两者均使用 3 个 WS 区域(对比结果)。

  • 默认 WG(尾声消费者):调用 tlx.clc_producertlx.clc_consumer

图 4. 在 tlx.async_tasks 外部初始化上下文,并在 ws-region 中调用生产者 API。

图 5. 在尾声 (epilogue) ws-region 中调用消费者 API。

  • 非默认 WG(MMA 消费者):仅调用 tlx.clc_consumer

图 6. 在 MMA ws-region 中调用消费者 API。

  • 非默认 WG(生产者,TMA 加载):仅调用 tlx.clc_consumer

图 7. 在 TMA 加载 ws-region 中调用消费者 API。

图 8. 镜像非持久化内核中使用的网格大小。

可视化流水线 GEMM 与 CLC GEMM 之间的差异

  • Y 轴:代表 144 个 SM,每个由其 SM ID(从 0 到 143)标识。
  • X 轴:代表时间(以时钟周期计),涵盖工作负载的持续时间。
  • 热力图大部分区域为黄色,意味着在大多数时钟周期内,SM 都被线程块占用。
  • CLC 通过消除流水线 GEMM 中的空闲间隙(紫色区域)实现了更好的性能。

图 9. 流水线 GEMM 与 CLC GEMM 之间的 SM 占用热力图对比。

  • 由于上述 GEMM 示例中所有线程块处理的工作负载大小相同,CLC 并未提升负载均衡。但对于线程块之间工作负载不均匀的内核,CLC 将极大改善负载均衡,如下图所示。

图 10. 启用 CLC 后内部内核的 SM 占用热力图。

致谢

衷心感谢 Bingyi Zhang (Nvidia) 关于 CLC 的启发性讨论,以及 Srivatsan Ramesh (Meta) 和 Yuanwei (Kevin) Fang (Meta) 在生成 SM 占用热力图方面提供的工具支持。