• 文档 >
  • 使用自定义转换器重载 Torch-TensorRT 转换器
快捷键

使用自定义转换器重载 Torch-TensorRT 转换器

如果出于某种原因,您想更改特定 PyTorch 运算到 TensorRT 的转换行为,您可以通过编写自定义转换器并重载 Torch-TensorRT 的转换器来实现。这可能是因为您想使用自定义内核而不是 TensorRT 的内核,或者因为您想在 TensorRT 中使用与 Torch-TensorRT 通常使用的层实现不同的实现。

在本教程中,我们将演示如何使用自定义转换器重载 Torch-TensorRT 对 torch.nn.functional.gelu 运算到 TensorRT 的转换,该自定义转换器使用 GeLU 层的不同实现。

import logging
import sys

import torch
import torch_tensorrt

GeLU 在 PyTorch 中有 2 种模式,一种使用 erf 函数,另一种使用 tanh 近似。TensorRT 本身支持这两种实现作为激活层,但假设我们只想在 TensorRT 中为 tanh 模式使用 GeLU 的自定义实现。

class GeLU(torch.nn.Module):
    def __init__(self, mode="tanh"):
        super().__init__()
        self.mode = mode

    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate=self.mode)


my_mod = GeLU(mode="tanh")
ex_input = torch.randn(2, 5).to("cuda")

作为基线,我们可以将标准 Torch-TensorRT GeLU 转换器(在 tanh 近似模式下)与我们的模块一起使用。

my_standard_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)
print(my_standard_gelu.graph)
print(my_standard_gelu(ex_input))

编写自定义转换器

转换器是函数,它接受 PyTorch 图中特定 PyTorch 运算的实例,并将其转换为正在构建的 TensorRT 图中的等效 TensorRT 运算集。它们使用 @torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter 装饰器在 Torch-TensorRT 中注册。在代码级别,转换器接受当前的转换状态 (ConversionCtx)、图中要转换的下一个运算符以及该节点的参数,并返回该运算的占位符输出,同时作为副作用将必要的 TensorRT 层插入到 TensorRT 网络中。

from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.conversion import ConversionContext

import tensorrt as trt

转换器元数据

@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
    # The PyTorch operation to convert, when this operation is encountered, this converter will be called
    torch.ops.aten.gelu.default,
    # Validators are functions that determine that given a specific node, if it can be converted by the converter
    capability_validator=lambda node, settings: (
        "approximate" in node.kwargs and node.kwargs["approximate"] == "tanh"
    ),
    # Can this converter be used in cases where the input shapes are dynamic
    supports_dynamic_shapes=True,
    # Set the priority of the converter to supersede the default one
    priority=torch_tensorrt.dynamo.conversion.ConverterPriority.HIGH,
)

对于定义转换器的装饰器,有一个必需的参数和一些可选参数。所有转换器都需要一个它们将针对的目标运算符,其思想是当图中存在 torch.ops.aten.gelu.default 的实例时,将调用此转换器。

在目标运算符之后,您可以提供额外的元数据,这些元数据定义了转换器的功能以及相对于目标问题的其他可能转换器的优先级

定义转换器功能的主要工具是 capability_validator 参数,它是一个 lambda 函数,接受图中的特定节点以及用户编译设置,并返回一个布尔值,指示该转换器是否可以用于该节点。此验证器函数在图分区阶段之前针对转换器目标运算的每个实例运行。在此阶段没有验证器通过的转换器的节点将在运行时在 PyTorch 中执行。这对于您只想在特定情况下使用自定义转换器的情况很有用,例如在我们的示例中,我们只想在 approximate == "tanh" 时使用我们的转换器。

与验证器不同的是 supports_dynamic_shapes 参数,它是一个布尔值,指示转换器是否可以用于输入形状是动态的情况。如果设置为 False,则在用户提供的输入是动态的情况下,将禁用此转换器。如果没有其他支持动态形状的替代方案,则该运算将在 PyTorch 中运行。

最后是 priority 参数,它是 torch_tensorrt.dynamo.conversion.ConverterPriority 类中的枚举,用于定义转换器的优先级。两个选项是 HIGHSTANDARD。使用 STANDARD 注册的转换器将附加到给定运算的转换器列表,而使用 HIGH 注册的转换器将预先添加到列表中。候选转换器根据其适用性按此优先级顺序进行评估,并使用第一个通过验证器的转换器。

转换器实现

转换器函数本身接受以下参数:当前转换上下文、目标运算符、目标运算符的参数、目标运算符的关键字参数以及目标运算符的名称。参数可以是任何 Python 原语、torch.Tensornp.ArraysITensor 对象。转换器函数应主要以 TensorRT ITensor 的形式返回目标运算符的输出。这些输入和输出应与目标 PyTorch 运算符的模式相对应,该模式可在此处找到 https://pytorch.ac.cn/docs/main/torch.compiler_ir.html

由于 Torch-TensorRT 涵盖了核心 ATen opset,它已经将许多常见的底层操作抽象为辅助函数,这些辅助函数可用于构建 TensorRT 网络。这允许开发者避免直接创建 TensorRT 层的样板代码,而专注于转换的高级逻辑。辅助函数位于 torch_tensorrt.dynamo.conversion.impl 模块中,旨在可组合且可与原始 TensorRT 实现互操作。在本例中,我们将使用 Torch-TensorRT 的 muladdtanh 函数(来自 impl)来实现我们的替代 GeLU 层。

def aten_ops_gelu(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # The schema for torch.ops.aten.gelu.default is gelu(Tensor self, *, str approximate=’none’) -> Tensor

    from torch_tensorrt.dynamo import SourceIR
    from torch_tensorrt.dynamo.conversion import impl

    # Cheap way to allow layer names to be unqiue
    op_count = 0

    def get_op_count():
        nonlocal op_count
        op_count += 1
        return op_count

    mul = lambda x, y: impl.elementwise.mul(
        ctx,
        target,
        name=f"mul_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    add = lambda x, y: impl.elementwise.add(
        ctx,
        target,
        name=f"add_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    tanh = lambda x: impl.activation.tanh(
        ctx, target, name=f"tanh_{get_op_count()}", source_ir=SourceIR.ATEN, input_val=x
    )

    # So we know that our custom converter is being run instead of the standard one
    print("\n\n---------------------------")
    print("Using custom GeLU converter")
    print("---------------------------\n\n")

    x_7 = mul(args[0], 0.5)
    x_8 = mul(args[0], 0.79788456080000003)
    x_9 = mul(args[0], 0.044714999999999998)
    x_10 = mul(x_9, args[0])
    x_11 = add(x_10, 1.0)
    x_12 = mul(x_8, x_11)
    x_13 = tanh(x_12)
    x_14 = add(x_13, 1.0)
    x_15 = mul(x_7, x_14)

    return x_15

使用我们的自定义转换器

我们现在可以重新编译并看到我们的自定义转换器正在被调用以将 GeLU 转换为 TensorRT。

my_custom_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)

print(my_custom_gelu.graph)
print(my_custom_gelu(ex_input))

我们可以验证我们的实现与 tanh 近似的 TensorRT 实现相匹配。

print(
    f"tanh approximations are close: {torch.allclose(my_standard_gelu(ex_input), my_custom_gelu(ex_input))}"
)

最后,我们要验证在 approximate 参数未设置为 tanh 的情况下,我们的自定义转换器未被使用。

my_mod_erf = GeLU(mode="none")
my_gelu_erf = torch_tensorrt.compile(
    my_mod_erf, arg_inputs=(ex_input,), min_block_size=1
)

请注意,我们没有看到来自自定义转换器的打印语句,表明它没有被使用。但是,查看图表,我们仍然可以看到创建了一个 TensorRT 引擎来运行 GeLU 运算。在这种情况下,我们的自定义转换器的验证器返回 False,因此转换系统继续列表中的下一个转换器,即标准 GeLU 转换器,并使用该转换器来转换运算。

print(my_gelu_erf.graph)
print(my_gelu_erf(ex_input))

脚本的总运行时间: (0 分钟 0.000 秒)

图库由 Sphinx-Gallery 生成

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源