Custom GPU Kernels via Triton

PyTorch/XLA now supports Triton kernels, enabling high-performance deep learning model execution on GPUs. Triton is a specialized language and compiler for GPU programming that lets developers write custom kernels which exploit the full potential of GPUs for the various operations in deep learning models.

Given a Triton kernel defined as follows:

@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
  # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
  pid = tl.program_id(axis=0)
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_elements
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
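
For intuition about the index arithmetic, here is a small host-side sketch (illustrative only, not part of the original example) of how the launch grid decomposes the vector: each program instance handles one BLOCK_SIZE-sized chunk, and the mask guards loads and stores past n_elements.

# Host-side sketch of the kernel's index math (illustrative only).
# With size = 16 and BLOCK_SIZE = 8 (the values used below), the launch
# grid has cdiv(16, 8) = 2 program instances: pid 0 covers offsets 0..7
# and pid 1 covers offsets 8..15.
size, BLOCK_SIZE = 16, 8
for pid in range((size + BLOCK_SIZE - 1) // BLOCK_SIZE):  # triton.cdiv
    block_start = pid * BLOCK_SIZE
    offsets = [block_start + i for i in range(BLOCK_SIZE)]
    in_bounds = [o for o in offsets if o < size]  # plays the role of `mask`
    print(f"pid={pid} handles offsets {in_bounds}")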

We can run this kernel as part of the PyTorch/XLA execution graph as follows:

import torch

import torch_xla.experimental.triton as xla_triton
import torch_xla

import triton
import triton.language as tl

size = 16
x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)

# triton_call takes the same arguments as the triton.jit function, in addition
# to the kernel itself and the grid that is used to execute the kernel.
# All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)

# To make the triton kernel, a part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, payload from triton_call,
# the output shapes and output dtypes. The payload already contains information
# regarding how the GPU buffers will be loaded when this node is executed.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                                [output.shape], [torch.int64])
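
As a sanity check, a sketch along the following lines could compare the custom call against plain tensor addition. This is a hypothetical illustration, not part of the original example: it assumes _xla_gpu_custom_call returns a list of output tensors (so the sum lives in output[0]) and that torch_xla.sync() is available to force execution of the pending graph.

# Hypothetical verification sketch (see assumptions above).
expected = x + y      # reference result computed by XLA itself
torch_xla.sync()      # force execution of the pending graph
# Assumes the custom call returned a list of outputs, one per output shape.
assert torch.equal(output[0].cpu(), expected.cpu())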

For a more complex kernel, you can also refer to the Triton flash attention kernel test in PyTorch/XLA.

Dependencies

The Triton integration depends on the triton package to function. This code was tested with triton==2.3.0. To install it, run:

pip install --no-deps triton==2.3.0
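
A quick way to confirm that the installed version matches the tested one (a hypothetical check, not from the original docs):

import triton

# The integration above was tested against triton==2.3.0; flag anything else.
assert triton.__version__ == "2.3.0", f"untested triton version: {triton.__version__}"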
