通过 Triton 自定义 GPU 内核¶
PyTorch/XLA 现在支持 Triton 内核,从而能够在 GPU 上实现高性能深度学习模型执行。Triton 是一种专门用于 GPU 编程的语言和编译器,使开发人员能够编写自定义内核,以充分利用 GPU 的潜力来执行深度学习模型中的各种操作。
给定一个如下定义的 Triton 内核
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
# Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
我们可以运行此内核,使其成为 PyTorch/XLA 执行图的一部分,如下所示
import torch
import torch_xla.experimental.triton as xla_triton
import torch_xla
import triton
import triton.language as tl
size = 16
x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)
# triton_call takes the same arguments as the triton.jit function, in addition
to the kernel itself and the grid that is used to execute the kernel.
All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
# To make the triton kernel, a part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, payload from triton_call,
# the output shapes and output dtypes. The payload already contains information
# regarding how the GPU buffers will be loaded when this node is executed.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
[output.shape], [torch.int64])
对于更复杂的内核,您还可以参考 PyTorch/XLA 中的 Triton Flash Attention 内核测试。