• 教程 >
  • 使用用户定义的 Triton 内核与 torch.compile
快捷方式

使用用户定义的 Triton 内核与 torch.compile

创建于: 2024年4月19日 | 最后更新于: 2025年3月7日 | 最后验证于: 2024年11月5日

作者: Oguz Ulgen

用户定义的 Triton 内核可用于优化模型计算的特定部分。这些内核使用 Triton 语言编写,该语言旨在更容易实现峰值硬件性能。通过将用户定义的 Triton 内核与 torch.compile 结合使用,您可以将这些优化计算集成到您的 PyTorch 模型中,从而可能显著提升性能。

本实用示例演示了如何将用户定义的 Triton 内核与 torch.compile 结合使用。

前提条件

在开始本实用示例之前,请确保您具备以下条件

import torch
from torch.utils._triton import has_triton

基本用法

在此示例中,我们将使用 Triton 文档中的一个简单向量加法内核与 torch.compile。有关参考,请参见 Triton 文档

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([ 0.1940,  2.1614, -0.1721,  0.8491], device='cuda:0')
Y:      tensor([ 0.1391, -0.1082, -0.7174,  0.7566], device='cuda:0')
is equal to
tensor([ 0.3332,  2.0532, -0.8895,  1.6057], device='cuda:0')

高级用法

Triton 的 autotune 功能是一个强大的工具,可自动优化 Triton 内核的配置参数。它会探索一系列可能的配置,并选择能为您的特定用例提供最佳性能的配置。

当与 torch.compile 一起使用时,triton.autotune 可以帮助确保您的 PyTorch 模型尽可能高效地运行。以下是使用 torch.compiletriton.autotune 的示例。

注意

torch.compile 仅支持 triton.autotune 的 configs 和 key arguments。

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([-0.5187,  1.2268,  0.6255, -0.9117], device='cuda:0')
Y:      tensor([-0.6974, -1.8688, -0.8832, -1.6627], device='cuda:0')
is equal to
tensor([-1.2161, -0.6421, -0.2577, -2.5744], device='cuda:0')

可组合性

用户定义的 Triton 内核并非自动支持所有 PyTorch 子系统。这可以在以下用例中看到

  • 添加 CPU 回退

  • 添加 FlopCounter 公式

  • 与张量子类组合

要与附加的 PyTorch 子系统组合,请使用 torch.library.triton_op

triton_op 是一种结构化的方式,用于定义由一个或多个 Triton 内核支持的自定义运算符:与常规自定义运算符 (torch.library.custom_op) 一样,您可以通过 torch.library 指定与 PyTorch 子系统的交互。然而,与 torch.library.custom_op 不同的是,torch.library.custom_op 会创建相对于 torch.compile 的不透明可调用对象,而 torch.compile 会跟踪到 triton_op 内部以应用优化。

下图显示了在将 Triton 内核与 PyTorch 集成时应使用哪种 API。

Triton 内核 (无显式 torch.library 包装器)

torch.library.triton_op

torch.library.custom_op

支持推理

支持训练

在大多数情况下

支持 torch.compile

支持 torch.compile(fullgraph=True)

在大多数情况下

在大多数情况下

在所有情况下

torch.compile 是否跟踪到实现内部?

支持 AOTInductor

支持 FlopCounterMode、CPU 回退、张量子类等 PyTorch 子系统

使用 triton_op 包装 Triton 内核

使用 torch.library.triton_op 包装可能调用一个或多个 Triton 内核的函数。使用 torch.library.wrap_triton 包装对 Triton 内核的调用。

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)

def sin_triton(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

您可以通过以下两种方式之一调用 triton_op

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

结果 triton_op 可与 torch.compileAOTInductor 一起使用。

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())

添加训练支持

使用 register_autogradtriton_op 添加一个自动微分公式。优先使用此方法,而不是使用 torch.autograd.Function (后者与 torch.compile 在可组合性方面存在各种问题)。

def backward(ctx, grad):
    x, = ctx.saved_tensors
    return grad * x.cos()

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)

请注意,反向传播必须是 PyTorch 理解的运算符的组合。如果您希望反向传播调用 Triton 内核,那么这些内核也必须包装在 triton_op 中。

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)

@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

def backward(ctx, grad):
    x, = ctx.saved_tensors
    return grad * mycos(x)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)

添加 CPU 回退

Triton 内核不在 CPU 上运行。使用 register_kerneltriton_op 添加一个 CPU (或任何其他设备) 回退。

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

回退必须由 PyTorch 运算符组成。

添加 FlopCounter 公式

要指定 Triton 内核在 PyTorch 的 flop 计数器下报告多少浮点运算 (flops),请使用 register_flop_formula

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel

x = torch.randn(3, device="cuda")

FlopCounterMode 需要 tabulate。在运行以下代码之前,请确保您已安装 tabulate,或者通过运行 pip install tabulate 进行安装。

>>> with FlopCounterMode() as flop_counter:
>>>     y = mysin(x)

限制

截至 PyTorch 2.3,torch.compile 对用户定义的 Triton 内核的支持包括动态形状、torch.autograd.Function、JIT inductor 和 AOT inductor。您可以将这些功能结合使用来构建复杂、高性能的模型。

PyTorch 2.6 添加了 torch.library.triton_op,它增加了对张量子类和其他高级功能中用户定义的 Triton 内核的支持。

但是,需要注意一些限制

  • Triton 功能: 虽然 triton.heuristics 可以单独使用或在 triton.autotune 之前使用,但不能在 triton.autotune 之后使用。这意味着如果 triton.heuristicstriton.autotune 需要一起使用,则必须先使用 triton.heuristics

结论

在本实用示例中,我们探讨了如何将用户定义的 Triton 内核与 torch.compile 结合使用。我们深入研究了简单向量加法内核的基本用法,以及涉及 Triton 的 autotune 功能的高级用法。我们还讨论了用户定义的 Triton 内核与其他 PyTorch 功能的可组合性,并强调了一些当前的限制。

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源