• 文档 >
  • 使用自定义转换器重载 Torch-TensorRT 转换器
快捷方式

使用自定义转换器重载 Torch-TensorRT 转换器

如果您出于某种原因想要更改特定 PyTorch 操作转换为 TensorRT 的转换行为,可以通过编写自定义转换器并重载 Torch-TensorRT 来做到这一点。这可能是出于以下原因:您希望使用自定义内核而不是 TensorRT 的内核,或者您希望在 TensorRT 中使用与 Torch-TensorRT 通常使用的层不同的层实现。

在本教程中,我们将演示如何使用自定义转换器重载 Torch-TensorRT 对 torch.nn.functional.gelu 操作的转换,该转换器使用 GeLU 层的不同实现。

import logging
import sys

import torch
import torch_tensorrt

GeLU 在 PyTorch 中有两种模式,一种使用 erf 函数,另一种使用 tanh 近似值。TensorRT 原生支持两种实现作为激活层,但是假设我们只想在 tanh 模式下在 TensorRT 中使用 GeLU 的自定义实现。

class GeLU(torch.nn.Module):
    def __init__(self, mode="tanh"):
        super().__init__()
        self.mode = mode

    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate=self.mode)


my_mod = GeLU(mode="tanh")
ex_input = torch.randn(2, 5).to("cuda")

作为基线,我们可以使用标准的 Torch-TensorRT GeLU 转换器(在 tanh 近似模式下)与我们的模块一起使用。

my_standard_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)
print(my_standard_gelu.graph)
print(my_standard_gelu(ex_input))

编写自定义转换器

转换器是函数,它们接受 PyTorch 图中的 PyTorch 操作的特定实例,并将其转换为正在构建的 TensorRT 图中的等效 TensorRT 操作集。它们使用 @torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter 装饰器注册到 Torch-TensorRT。在代码级别,转换器接受当前转换状态 (ConversionCtx)、图中要转换的下一个操作符以及该节点的参数,并返回该操作的占位符输出,同时作为副作用将必要的 TensorRT 层插入到 TensorRT 网络中。

from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.conversion import ConversionContext

import tensorrt as trt

转换器元数据

@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
    # The PyTorch operation to convert, when this operation is encountered, this converter will be called
    torch.ops.aten.gelu.default,
    # Validators are functions that determine that given a specific node, if it can be converted by the converter
    capability_validator=lambda node, settings: (
        "approximate" in node.kwargs and node.kwargs["approximate"] == "tanh"
    ),
    # Can this converter be used in cases where the input shapes are dynamic
    supports_dynamic_shapes=True,
    # Set the priority of the converter to supersede the default one
    priority=torch_tensorrt.dynamo.conversion.ConverterPriority.HIGH,
)

对于定义转换器的装饰器,有一个必需的参数和几个可选参数。所有转换器都需要一个它们将针对的目标操作符,想法是当图中存在 torch.ops.aten.gelu.default 的实例时,将调用此转换器。

在目标操作符之后,您可以提供其他元数据来定义转换器的功能和转换器相对于针对此目标的其他可能转换器的优先级

定义转换器功能的主要工具是 capability_validator 参数,它是一个 lambda 函数,它接受图中的特定节点以及用户编译设置,并返回一个布尔值,指示转换器是否可以用于该节点。此验证器函数在图分区阶段之前针对转换器目标操作符的每个实例运行。在验证器通过此阶段的转换器不存在的节点将在运行时在 PyTorch 中执行。这对于您只想在特定情况下使用自定义转换器的情况很有用,例如在我们的情况下,我们只想在 approximate == "tanh" 时使用我们的转换器。

与验证器不同的是 supports_dynamic_shapes 参数,它是一个布尔值,指示转换器是否可以在输入形状为动态的情况下使用。如果将其设置为 False,则在用户提供的输入为动态的情况下,将禁用此转换器。如果没有支持动态形状的替代方案,则将在 PyTorch 中运行该操作。

最后是 priority 参数,它是一个枚举,来自 torch_tensorrt.dynamo.conversion.ConverterPriority 类,用于定义转换器的优先级。两个选项是 HIGHSTANDARD。使用 STANDARD 注册的转换器将附加到给定操作的转换器列表中,而使用 HIGH 注册的转换器将被预先添加到列表中。按此优先级顺序评估候选转换器以确定其适用性,第一个通过验证器的转换器将被使用。

转换器实现

转换器函数本身接受以下参数:当前转换上下文、目标操作符、目标操作符的参数、目标操作符的关键字参数以及目标操作符的名称。参数可以是任何 Python 原语、torch.Tensornp.ArraysITensor 对象。转换器函数应主要以 TensorRT ITensor 的形式返回目标操作符的输出。这些输入和输出应对应于目标 PyTorch 操作符的模式,可以在此处找到 https://pytorch.ac.cn/docs/main/torch.compiler_ir.html

由于 Torch-TensorRT 涵盖了核心 ATen 操作集,它已经将许多常见的低级操作抽象到帮助函数中,这些函数可用于构建 TensorRT 网络。 这使开发人员无需直接创建 TensorRT 层,而是可以专注于转换的高级逻辑。 这些帮助函数位于 torch_tensorrt.dynamo.conversion.impl 模块中,旨在与原始 TensorRT 实现组合和互操作。 在这种情况下,我们将使用 Torch-TensorRT 的 muladdtanh 函数(来自 impl)来实现我们的替代 GeLU 层。

def aten_ops_gelu(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # The schema for torch.ops.aten.gelu.default is gelu(Tensor self, *, str approximate=’none’) -> Tensor

    from torch_tensorrt.dynamo import SourceIR
    from torch_tensorrt.dynamo.conversion import impl

    # Cheap way to allow layer names to be unqiue
    op_count = 0

    def get_op_count():
        nonlocal op_count
        op_count += 1
        return op_count

    mul = lambda x, y: impl.elementwise.mul(
        ctx,
        target,
        name=f"mul_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    add = lambda x, y: impl.elementwise.add(
        ctx,
        target,
        name=f"add_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    tanh = lambda x: impl.activation.tanh(
        ctx, target, name=f"tanh_{get_op_count()}", source_ir=SourceIR.ATEN, input_val=x
    )

    # So we know that our custom converter is being run instead of the standard one
    print("\n\n---------------------------")
    print("Using custom GeLU converter")
    print("---------------------------\n\n")

    x_7 = mul(args[0], 0.5)
    x_8 = mul(args[0], 0.79788456080000003)
    x_9 = mul(args[0], 0.044714999999999998)
    x_10 = mul(x_9, args[0])
    x_11 = add(x_10, 1.0)
    x_12 = mul(x_8, x_11)
    x_13 = tanh(x_12)
    x_14 = add(x_13, 1.0)
    x_15 = mul(x_7, x_14)

    return x_15

使用我们的自定义转换器

现在,我们可以重新编译并看到我们的自定义转换器正在被调用以将 GeLU 转换为 TensorRT。

my_custom_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)

print(my_custom_gelu.graph)
print(my_custom_gelu(ex_input))

我们可以验证我们的实现与 TensorRT 实现对于 tanh 近似的匹配情况。

print(
    f"tanh approximations are close: {torch.allclose(my_standard_gelu(ex_input), my_custom_gelu(ex_input))}"
)

最后,我们希望验证在 approximate 参数未设置为 tanh 的情况下,我们的自定义转换器不会被使用。

my_mod_erf = GeLU(mode="none")
my_gelu_erf = torch_tensorrt.compile(
    my_mod_erf, arg_inputs=(ex_input,), min_block_size=1
)

请注意,我们没有看到来自自定义转换器的打印语句,这表明它没有被使用。 但是,查看图形,我们仍然可以看到创建了 TensorRT 引擎来运行 GeLU 操作。 在这种情况下,我们自定义转换器的验证器返回了 False,因此转换系统继续执行列表中的下一个转换器,即标准 GeLU 转换器,并使用它来转换操作。

print(my_gelu_erf.graph)
print(my_gelu_erf(ex_input))

脚本的总运行时间:(0 分钟 0.000 秒)

Sphinx-Gallery 生成的画廊

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源