Triton 开源编程语言和编译器提供了一种基于 Python 的高级方法来创建高效的 GPU 代码。在这篇博客中,我们重点介绍 Triton 程序如何编译及其中间表示的底层细节。关于 Triton 的介绍,请读者参考这篇博客。
Triton 语言和编译
Triton 编程语言支持不同类型的现代 GPU,并遵循分块编程方法。举例来说,我们将稍作修改,参照 Triton 向量加法教程。向量加法内核和辅助函数定义如下:
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor):
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
triton_kernel=add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
torch.cuda.synchronize()
# Save compilation stages - some of the stages identified here are specific to NVIDIA devices:
with open('triton_IR.txt', 'w') as f:
print(triton_kernel.asm['ttir'], file=f)
with open('triton_TTGIR.txt', 'w') as f:
print(triton_kernel.asm['ttgir'], file=f)
with open('triton_LLVMIR.txt', 'w') as f:
print(triton_kernel.asm['llir'], file=f)
with open('triton_PTX.ptx', 'w') as f:
print(triton_kernel.asm['ptx'], file=f)
with open('triton_cubin.txt', 'w') as f:
print(triton_kernel.asm['cubin'], file=f)
return output
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
Triton 向量加法内核包含 @triton.jit
装饰器。Triton 编译器将编译标记为 @triton.jit
的函数,这些函数会经过多个编译阶段进行降级。辅助函数 add
分配输出张量,计算合适的 GPU 网格大小,并额外保存中间编译阶段。
专注于编译过程,Triton 内核通过下图所示的一系列阶段降级为设备特定的汇编代码。
内核的编译过程首先遍历被修饰的 Python 函数的抽象语法树 (AST) 以创建 Triton 中间表示 (Triton-IR)。Triton-IR 是一种未经优化、机器无关的中间表示。它引入了瓦片级编程要求,并基于开源 LLVM 编译器项目。接下来,Triton 编译器优化并将 Triton-IR 转换为 Triton-GPU IR (Triton-TTGIR) 阶段,然后再转换为 LLVM-IR。Triton-IR 和 Triton-GPUIR 表示都写成 MLIR 方言,MLIR 是 LLVM 的一个子项目,旨在改进异构硬件的编译。
对于 Triton 向量加法教程内核,Triton IR 代码片段示例如下:
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/u/saraks/triton_blog/01-vector-add.py":28:0)) attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
%0 = tt.get_program_id x : i32 loc(#loc2)
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
%4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6)
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6)
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc7)
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc7)
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc8)
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc9)
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc9)
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc10)
%13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11)
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc12)
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc12)
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13)
tt.return loc(#loc14)
} loc(#loc)
} loc(#loc)
请注意,Triton 内核中的主要函数现在表示为
Triton 内核 | Triton IR |
x = tl.load(x_ptr + offsets, mask=mask) | %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc8) |
y = tl.load(y_ptr + offsets, mask=mask) | %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc10) |
output = x + y | %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11) |
tl.store(output_ptr + offsets, output, mask=mask) | tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13) |
在 Triton IR 阶段,%arg0: !tt.ptr<f32>
和后面的张量引用显示中间表示已经按数据类型进行了特化。
我们在配备 CUDA 12.2、Python 3.11.9 和 PyTorch 2.4.1(以及 PyTorch 安装的默认 Triton 版本)的 Tesla V100-SXM2-32GB GPU 上运行此示例。在此设备上,简单的向量加法具有以下 Triton GPU IR 代码片段(为清晰起见省略了部分行):
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:70", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}
⋮
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc8)
⋮
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc10)
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11)
⋮
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
⋮
} loc(#loc)
} loc(#loc)
在此阶段,包含了一些硬件特定信息。例如,包含计算能力以及张量如何分配到核心和 warp(或 AMD GPU 上的 wavefront)的详细信息。在此示例中,张量表示为 #blocked
布局。在此编码中,每个 warp 拥有张量的一个连续部分。目前,其他可能的内存优化包括 slice
(沿某个维度重构和分布张量)、dot_op
(块矩阵乘法的优化布局)、shared
(表示 GPU 共享内存)、nvidia_mma
(由 NVIDIA 张量核心产生)、amd_mfma
(由 AMD MFMA 矩阵核心产生)和 amd_wmma
(由 AMD WMMA 矩阵核心产生)等布局。正如最近在 Triton 大会上宣布的那样,这种布局表示将过渡到新的线性布局,以统一后端内部和跨后端的布局。从 Triton-GPUIR 到 LLVM-IR 的阶段将 Triton-GPUIR 转换为 LLVM 的表示形式。目前,Triton 对 NVIDIA 和 AMD 设备有第三方后端支持,但其他设备的支持正在由开源社区积极开发中。
为说明起见,下面展示 LLVM-IR 向量加法参数的一个小子集:
%19 = extractvalue { i32, i32, i32, i32 } %18, 0, !dbg !16
%39 = extractvalue { i32, i32, i32, i32 } %38, 0, !dbg !18
%23 = bitcast i32 %19 to float, !dbg !16
%43 = bitcast i32 %39 to float, !dbg !18
%56 = fadd float %23, %43, !dbg !19
在进行一些指针算术运算并内联汇编调用以从全局内存中检索数据后,向量元素被提取并转换为正确类型。最后,它们相加,随后通过内联汇编表达式写入全局内存。
Triton 编译过程的最后阶段将 LLVM-IR 降级为设备特定的二进制文件。对于向量加法示例,在 NVIDIA GPU 上,下一个中间阶段是 PTX (Parallel Thread Execution)。低级 PTX 语法指定了 NVIDIA 设备在线程级别的执行,始于 CUDA 1.0 版本。有关 PTX 的深入指南,请参阅 NVIDIA 的文档。在向量加法中,内核参数从主机传递到内核,分配地址,并且 mov
指令促进线程级别的数据访问,最终通过 add.f32
表示元素加法调用,例如以下示例:
add.f32 %f17, %f1, %f9// add type float32, output register, input register for x, input register for y
Triton 编译器协调最后阶段,不同的硬件后端管理汇编代码如何编译成二进制文件。Triton 内核现在已可使用。
总结
Triton 提供了一种高级抽象,用于为不同类型的硬件编写和编译内核。在这篇文章中,我们重点介绍了 Triton 代码表示和 Triton 编译器的不同阶段。有关包含自定义 Triton 内核或使用 Triton 内核加速不同工作负载的详细信息,请查看 PyTorch Triton 教程,以及关于 Triton GPTQ 内核、使用 Triton 进行 Llama3 FP8 推理和 用于 LLM 的无 CUDA 推理的博客文章,或 PyTorch 2.2 中关于 Triton 代码生成的部分。