• 教程 >
  • 使用 torch.compile 使用用户定义的 Triton 内核
快捷方式

使用 torch.compile 使用用户定义的 Triton 内核

作者: Oguz Ulgen

用户定义的 Triton 内核可用于优化模型计算的特定部分。这些内核是用 Triton 语言编写的,该语言旨在简化实现硬件峰值性能的过程。通过将用户定义的 Triton 内核与 torch.compile 结合使用,您可以将这些优化的计算集成到 PyTorch 模型中,从而有可能实现显著的性能提升。

此食谱演示了如何将用户定义的 Triton 内核与 torch.compile 结合使用。

先决条件

在开始此食谱之前,请确保您具备以下条件

import torch
from torch.utils._triton import has_triton

基本用法

在此示例中,我们将使用 Triton 文档中的一个简单的向量加法内核与 torch.compile 结合使用。作为参考,请参阅 Triton 文档

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([ 0.1940,  2.1614, -0.1721,  0.8491], device='cuda:0')
Y:      tensor([ 0.1391, -0.1082, -0.7174,  0.7566], device='cuda:0')
is equal to
tensor([ 0.3332,  2.0532, -0.8895,  1.6057], device='cuda:0')

高级用法

Triton 的自动调整功能是一个强大的工具,它会自动优化 Triton 内核的配置参数。它会探索一系列可能的配置,并选择最适合您的特定用例的配置,从而提供最佳性能。

当与 torch.compile 一起使用时,triton.autotune 可以帮助确保您的 PyTorch 模型以尽可能高效的方式运行。以下是如何使用 torch.compiletriton.autotune 的示例。

注意

torch.compile 仅支持 triton.autotune 的配置和关键参数。

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([-0.5187,  1.2268,  0.6255, -0.9117], device='cuda:0')
Y:      tensor([-0.6974, -1.8688, -0.8832, -1.6627], device='cuda:0')
is equal to
tensor([-1.2161, -0.6421, -0.2577, -2.5744], device='cuda:0')

组合性和限制

截至 PyTorch 2.3,对 torch.compile 中用户定义的 Triton 内核的支持包括动态形状、torch.autograd.Function、JIT 感应器和 AOT 感应器。您可以将这些功能结合起来构建复杂的高性能模型。

但是,需要注意一些限制

  • 张量子类:目前,不支持张量子类和其他高级功能。

  • Triton 功能:虽然 triton.heuristics 可以独立使用或在 triton.autotune 之前使用,但它不能在 triton.autotune 之后使用。这意味着,如果要将 triton.heuristicstriton.autotune 结合使用,则必须先使用 triton.heuristics

结论

在本教程中,我们探讨了如何将用户定义的 Triton 内核与 torch.compile 一起使用。我们深入研究了简单向量加法内核的基本用法以及涉及 Triton 自动调整功能的高级用法。我们还讨论了用户定义的 Triton 内核与其他 PyTorch 功能的组合性,并强调了一些当前的限制。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源