Triton 开源编程语言和编译器提供了一种基于 Python 的高级方法来创建高效的 GPU 代码。在这篇博客中,我们将重点介绍 Triton 程序如何编译以及中间表示形式的底层细节。关于 Triton 的介绍,请读者参考这篇博客

Triton 语言与编译

Triton 编程语言支持多种类型的现代 GPU,并遵循块状(blocked)编程方法。作为示例,我们将参考Triton 向量加法教程并进行微调。向量加法内核及辅助函数定义如下:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements, 
               BLOCK_SIZE: tl.constexpr, 
               ):
  
    pid = tl.program_id(axis=0) 
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
 
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
 
def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    triton_kernel=add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    torch.cuda.synchronize()

    # Save compilation stages - some of the stages identified here are specific to NVIDIA devices:
    with open('triton_IR.txt', 'w') as f:
        print(triton_kernel.asm['ttir'], file=f)
    with open('triton_TTGIR.txt', 'w') as f:
        print(triton_kernel.asm['ttgir'], file=f)
    with open('triton_LLVMIR.txt', 'w') as f:
        print(triton_kernel.asm['llir'], file=f)
    with open('triton_PTX.ptx', 'w') as f:
        print(triton_kernel.asm['ptx'], file=f)
    with open('triton_cubin.txt', 'w') as f:
        print(triton_kernel.asm['cubin'], file=f)

    return output

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')    

Triton 向量加法内核包含 @triton.jit 装饰器。Triton 编译器将编译由 @triton.jit 标记的函数,通过多个编译阶段对其进行降级(lowering)。辅助函数 add 用于分配输出张量、计算合适的 GPU 网格大小,并额外保存中间编译阶段的内容。

聚焦于编译过程,Triton 内核通过下图中概述的一系列阶段被降级为特定设备的汇编代码。

flow diagram

内核的编译首先通过遍历已装饰的 Python 函数的抽象语法树(AST)来创建 Triton 中间表示(Triton-IR)。Triton-IR 是一种未经优化、与机器无关的中间表示。它引入了块级(tile-level)编程需求,并基于开源 LLVM 编译器项目。接下来,Triton 编译器会对 Triton-IR 进行优化并将其转换为 Triton-GPU IR(Triton-TTGIR)阶段,随后转换为 LLVM-IR。Triton-IR 和 Triton-GPUIR 这两种表示形式均编写为 MLIR 方言,MLIR 是 LLVM 的一个子项目,旨在改善异构硬件的编译效率。

对于 Triton 向量加法教程内核,示例 Triton IR 片段如下:

module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0)) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6)
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6)
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc7)
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc7)
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc8)
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc9)
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc9)
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc10)
    %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11)
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc12)
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc12)
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13)
    tt.return loc(#loc14)
  } loc(#loc)
} loc(#loc)

请注意,Triton 内核中的主要函数现在表示为:

Triton 内核 Triton IR
x = tl.load(x_ptr + offsets, mask=mask) %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc8)
y = tl.load(y_ptr + offsets, mask=mask) %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc10)
output = x + y %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11)
tl.store(output_ptr + offsets, output, mask=mask) tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13)

在 Triton IR 阶段,%arg0: !tt.ptr<f32> 以及随后的张量引用表明,中间表示已经根据数据类型进行了专门化。

我们在配备 CUDA 12.2、Python 3.11.9 和 PyTorch 2.4.1(使用随 PyTorch 安装的默认 Triton 版本)的 Tesla V100-SXM2-32GB GPU 上运行了此示例。在该设备上,简单的向量加法具有以下 Triton GPU IR 片段(为清晰起见省略了部分行):

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:70", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}
    ⋮
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc8)
    ⋮
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc10)
    %13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11)
    ⋮
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
    ⋮
  } loc(#loc)
} loc(#loc)

在此阶段,包含了一些特定于硬件的信息。例如,计算能力(compute capability)以及张量如何分布到核心(cores)和线程束(warps)上的详细信息(对于 AMD GPU,则是 wavefronts)。在此示例中,张量表示为 #blocked 布局。在此编码中,每个线程束拥有张量的一个连续部分。目前,其他可能的内存优化包括 slice(沿维度重构和分布张量)、dot_op(针对块矩阵乘法的优化布局)、shared(指示 GPU 共享内存)、nvidia_mma(由 NVIDIA 张量核心产生)、amd_mfma(由 AMD MFMA 矩阵核心产生)和 amd_wmma(由 AMD WMMA 矩阵核心产生)。正如在近期 Triton 大会上所宣布的那样,这种布局表示将过渡到一种新的线性布局,以统一后端内部及后端之间的布局。从 Triton-GPUIR 到 LLVM-IR 的阶段将 Triton-GPUIR 转换为 LLVM 的表示形式。目前,Triton 拥有对 NVIDIA 和 AMD 设备的第三方后端支持,但其他设备的支持正处于开源社区的积极开发中。

为了说明,下面展示了 LLVM-IR 向量加法参数的一小部分:

  %19 = extractvalue { i32, i32, i32, i32 } %18, 0, !dbg !16
  %39 = extractvalue { i32, i32, i32, i32 } %38, 0, !dbg !18
  %23 = bitcast i32 %19 to float, !dbg !16
  %43 = bitcast i32 %39 to float, !dbg !18
  %56 = fadd float %23, %43, !dbg !19

在进行一些指针运算和内联汇编调用以从全局内存获取数据后,向量元素被提取并转换为正确的类型。最后,它们相加,随后通过内联汇编表达式写回全局内存。

Triton 编译过程的最后阶段将 LLVM-IR 降级为特定设备的二进制文件。对于向量加法示例,在 NVIDIA GPU 上,下一个中间表示是 PTX(并行线程执行)。自 CUDA 1.0 发布以来,低级 PTX 语法指定了 NVIDIA 设备在线程级别的执行方式。有关 PTX 的深入指南,请参阅 NVIDIA 文档。在向量加法中,内核参数从主机传递到内核,分配地址,并使用 mov 指令促进线程级数据访问,最终通过 add.f32(如下例所示)表示元素加法调用。

	add.f32 	%f17, %f1, %f9// add type float32, output register, input register for x, input register for y 

Triton 编译器协调最后阶段,由不同的硬件后端管理汇编代码如何编译成二进制文件。至此,Triton 内核即可使用。

总结

Triton 提供了一种高级抽象,用于为不同类型的硬件编程和编译内核。在这篇文章中,我们重点介绍了 Triton 代码表示和 Triton 编译器的不同阶段。有关包含自定义 Triton 内核或使用 Triton 内核加速不同工作负载的详细信息,请查看 PyTorch Triton 教程、关于 Triton GPTQ 内核使用 Triton 进行 Llama3 FP8 推理LLM 的无 CUDA 推理 的博客文章,或者 PyTorch 2.2 关于 Triton 代码生成的章节