• 文档 >
  • Torch-TensorRT(FX 前端)用户指南
快捷方式

Torch-TensorRT(FX 前端)用户指南

Torch-TensorRT(FX 前端)是一种工具,可以通过 torch.fx 将 PyTorch 模型转换为 TensorRT 引擎,该引擎针对在 Nvidia GPU 上运行进行了优化。TensorRT 是 NVIDIA 开发的推理引擎,它包含各种优化,包括内核融合、图优化、低精度等。该工具是在 Python 环境中开发的,这使得该工作流程对于研究人员和工程师来说非常容易使用。用户可以使用该工具完成几个步骤,我们将在本文档中介绍这些步骤。

> Torch-TensorRT(FX 前端)处于 Beta 阶段,目前建议使用 PyTorch nightly 版本。

# Test an example by
$ python py/torch_tensorrt/fx/example/lower_example.py

将 PyTorch 模型转换为 TensorRT 引擎

通常,用户可以使用 compile() 完成从模型到 TensorRT 引擎的转换。它是一个包装器 API,包含完成此转换所需的几个主要步骤。请参阅 examples/fx 目录下的 lower_example.py 文件中的示例用法。

def compile(
    module: nn.Module,
    input,
    max_batch_size=2048,
    max_workspace_size=33554432,
    explicit_batch_dimension=False,
    lower_precision=LowerPrecision.FP16,
    verbose_log=False,
    timing_cache_prefix="",
    save_timing_cache=False,
    cuda_graph_batch_size=-1,
    dynamic_batch=True,
) -> nn.Module:

    """
    Takes in original module, input and lowering setting, run lowering workflow to turn module
    into lowered module, or so called TRTModule.

    Args:
        module: Original module for lowering.
        input: Input for module.
        max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
        max_workspace_size: Maximum size of workspace given to TensorRT.
        explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
        lower_precision: lower_precision config given to TRTModule.
        verbose_log: Enable verbose log for TensorRT if set True.
        timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
        save_timing_cache: Update timing cache with current timing cache data if set to True.
        cuda_graph_batch_size: Cuda graph batch size, default to be -1.
        dynamic_batch: batch dimension (dim=0) is dynamic.
    Returns:
        A torch.nn.Module lowered by TensorRT.
    """

在本节中,我们将通过一个示例来说明 fx 路径使用的主要步骤。用户可以参考 examples/fx 目录下的 fx2trt_example.py 文件。

  • 步骤 1:使用 acc_tracer 追踪模型

Acc_tracer 是从 FX 追踪器继承来的追踪器。它带有一个参数规范器,用于将所有参数转换为关键字参数并传递给 TRT 转换器。

import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer

# Build the model which needs to be a PyTorch nn.Module.
my_pytorch_model = build_model()

# Prepare inputs to the model. Inputs have to be a List of Tensors
inputs = [Tensor, Tensor, ...]

# Trace the model with acc_tracer.
acc_mod = acc_tracer.trace(my_pytorch_model, inputs)

常见错误

符号追踪的变量不能用作控制流的输入,这意味着模型包含动态控制流。请参考 FX 指南 中的“动态控制流”一节。

  • 步骤 2:构建 TensorRT 引擎

TensorRT 处理批次维度的方式有 两种不同的模式,即显式批次维度和隐式批次维度。这种模式在早期版本的 TensorRT 中使用,现在已弃用,但为了向后兼容,仍然支持。在显式批次模式下,所有维度都是显式的,并且可以是动态的,也就是说它们在运行时可以改变长度。许多新功能,例如动态形状和循环,只在该模式下可用。用户仍然可以选择在 compile() 中设置 explicit_batch_dimension=False 来使用隐式批次模式。我们不建议使用它,因为它在将来的 TensorRT 版本中将不再支持。

显式批次是默认模式,它必须为动态形状设置。对于大多数视觉任务,用户可以选择在 compile() 中启用 dynamic_batch,如果他们想获得与隐式模式类似的效果,在隐式模式下,只有批次维度会改变。它有一些要求:1. 输入、输出和激活的形状除了批次维度之外都是固定的。2. 输入、输出和激活的批次维度是主维度。3. 模型中的所有运算符不会修改批次维度(置换、转置、分割等)或在批次维度上计算(求和、softmax 等)。

例如,对于最后一个路径,如果我们有一个 3D 张量 t,其形状为 (batch, sequence, dimension),则 torch.transpose(0, 2) 等运算将被禁止。如果这三个条件中的任何一个不满足,我们需要将 InputTensorSpec 作为输入,并指定动态范围。

import deeplearning.trt.fx2trt.converter.converters
from torch.fx.experimental.fx2trt.fx2trt import InputTensorSpec, TRTInterpreter

# InputTensorSpec is a dataclass we use to store input information.
# There're two ways we can build input_specs.
# Option 1, build it manually.
input_specs = [
  InputTensorSpec(shape=(1, 2, 3), dtype=torch.float32),
  InputTensorSpec(shape=(1, 4, 5), dtype=torch.float32),
]
# Option 2, build it using sample_inputs where user provide a sample
inputs = [
torch.rand((1,2,3), dtype=torch.float32),
torch.rand((1,4,5), dtype=torch.float32),
]
input_specs = InputTensorSpec.from_tensors(inputs)

# IMPORTANT: If dynamic shape is needed, we need to build it slightly differently.
input_specs = [
    InputTensorSpec(
        shape=(-1, 2, 3),
        dtype=torch.float32,
        # Currently we only support one set of dynamic range. User may set other dimensions but it is not promised to work for any models
        # (min_shape, optimize_target_shape, max_shape)
        # For more information refer to fx/input_tensor_spec.py
        shape_ranges = [
            ((1, 2, 3), (4, 2, 3), (100, 2, 3)),
        ],
    ),
    InputTensorSpec(shape=(1, 4, 5), dtype=torch.float32),
]

# Build a TRT interpreter. Set explicit_batch_dimension accordingly.
interpreter = TRTInterpreter(
    acc_mod, input_specs, explicit_batch_dimension=True/False
)

# The output of TRTInterpreter run() is wrapped as TRTInterpreterResult.
# The TRTInterpreterResult contains required parameter to build TRTModule,
# and other informational output from TRTInterpreter run.
class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
    output_names: Sequence[str]
    serialized_cache: bytearray

#max_batch_size: set accordingly for maximum batch size you will use.
#max_workspace_size: set to the maximum size we can afford for temporary buffer
#lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
#sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
#force_fp32_output: force output to be fp32
#strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric #reasons.
#algorithm_selector: set up algorithm selection for certain layer
#timing_cache: enable timing cache for TensorRT
#profiling_verbosity: TensorRT logging level
trt_interpreter_result = interpreter.run(
    max_batch_size=64,
    max_workspace_size=1 << 25,
    sparse_weights=False,
    force_fp32_output=False,
    strict_type_constraints=False,
    algorithm_selector=None,
    timing_cache=None,
    profiling_verbosity=None,
)

常见错误

RuntimeError: Conversion of function xxx not currently supported! - 这意味着我们目前不支持 xxx 运算符。请参考下面的“如何添加缺失的运算符”一节以获取更多说明。

  • 步骤 3:运行模型

一种方法是使用 TRTModule,它基本上是一个 PyTorch nn.Module。

from torch_tensorrt.fx import TRTModule
mod = TRTModule(
    trt_interpreter_result.engine,
    trt_interpreter_result.input_names,
    trt_interpreter_result.output_names)
# Just like all other PyTorch modules
outputs = mod(*inputs)
torch.save(mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")
reload_model_output = reload_trt_mod(*inputs)

到目前为止,我们详细解释了将 PyTorch 模型转换为 TensorRT 引擎的主要步骤。欢迎用户参考源代码了解某些参数的解释。在转换方案中,有两个重要的操作。一个是 acc 追踪器,它可以帮助我们将 PyTorch 模型转换为 acc 图。另一个是 FX 路径转换器,它可以帮助将 acc 图的运算转换为相应的 TensorRT 运算,并为其构建 TensorRT 引擎。

Acc 追踪器

Acc 追踪器是一个自定义 FX 符号追踪器。与标准 FX 符号追踪器相比,它还做了几件事。我们主要依靠它来将 PyTorch 运算或内置运算转换为 acc 运算。fx2trt 使用 acc 运算有两种主要目的

  1. PyTorch 运算和内置运算中有很多运算执行类似的操作,例如 torch.add、builtin.add 和 torch.Tensor.add。使用 acc 追踪器,我们将这三个运算标准化为一个单独的 acc_ops.add。这有助于减少我们需要编写的转换器数量。

  2. acc 运算只有关键字参数,这使得编写转换器更容易,因为我们不需要添加额外的逻辑来在参数和关键字参数中查找参数。

FX2TRT

完成符号追踪后,我们获得了 PyTorch 模型的图表示。fx2trt 利用了 fx.Interpreter 的强大功能。fx.Interpreter 依次遍历整个图的节点,并调用节点表示的函数。fx2trt 覆盖了调用函数的原始行为,改为调用每个节点的相应转换器。每个转换器函数都会添加相应的 TensorRT 层。

下面是一个转换器函数的示例。装饰器用于将该转换器函数注册到相应的节点。在此示例中,我们将此转换器注册到一个 fx 节点,该节点的目标是 acc_ops.sigmoid。

@tensorrt_converter(acc_ops.sigmoid)
def acc_ops_sigmoid(network, target, args, kwargs, name):
    """
    network: TensorRT network. We'll be adding layers to it.

    The rest arguments are attributes of fx node.
    """
    input_val = kwargs['input']

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f'Sigmoid received input {input_val} that is not part '
                        'of the TensorRT region!')

    layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID)
    layer.name = name
    return layer.get_output(0)

如何添加缺失的运算符

您实际上可以在任何地方添加它,只需记住导入该文件,以便在使用 acc_tracer 追踪之前注册所有 acc 运算和映射器。

  • 步骤 1. 添加新的 acc 运算

TODO: 需要进一步解释 acc 运算的逻辑,例如,当我们想要分解运算符时,以及当我们想要重用其他运算符时。

acc 追踪器 中,如果图中的节点到 acc 运算的映射已注册,我们将把该节点转换为 acc 运算。

为了使转换为 acc 运算的过程发生,需要两个条件。一个条件是应定义 acc 运算函数,另一个条件是应注册映射。

定义一个 acc 操作很简单,我们只需要一个函数,并使用这个装饰器 acc_normalizer.py 将其注册为一个 acc 操作。例如,以下代码添加了一个名为 foo() 的 acc 操作,它将两个给定的输入相加。

# NOTE: all acc ops should only take kwargs as inputs, therefore we need the "*"
# at the beginning.
@register_acc_op
def foo(*, input, other, alpha):
    return input + alpha * other

有两种方法可以注册映射。一种是 register_acc_op_mapping()。让我们注册一个从 torch.add 到我们上面创建的 foo() 的映射。我们需要将装饰器 register_acc_op_mapping 添加到它。

this_arg_is_optional = True

@register_acc_op_mapping(
    op_and_target=("call_function", torch.add),
    arg_replacement_tuples=[
        ("input", "input"),
        ("other", "other"),
        ("alpha", "alpha", this_arg_is_optional),
    ],
)
@register_acc_op
def foo(*, input, other, alpha=1.0):
    return input + alpha * other

op_and_target 决定哪个节点将触发此映射。op 和 target 是 FX 节点的属性。在 acc_normalization 中,当我们看到一个节点,其 op 和 target 与 op_and_target 中设置的相同,我们将触发映射。由于我们希望从 torch.add 进行映射,那么 op 将是 call_function,target 将是 torch.addarg_replacement_tuples 决定了我们如何使用来自原始节点的 args 和 kwargs 为新的 acc 操作节点构建 kwargs。 arg_replacement_tuples 中的每个元组代表一个参数映射规则。它包含两个或三个元素。第三个元素是一个布尔变量,用于确定此 kwarg 在原始节点中是否可选。只有在它为 True 时,我们才需要指定第三个元素。第一个元素是原始节点中的参数名称,它将用作 acc 操作节点的参数名称,该节点的名称是元组中的第二个元素。元组的顺序很重要,因为元组的位置决定了参数在原始节点的 args 中的位置。我们使用这些信息将 args 从原始节点映射到 acc 操作节点中的 kwargs。如果以下情况都不满足,则我们不必指定 arg_replacement_tuples。

  1. 原始节点和 acc 操作节点的 kwargs 具有不同的名称。

  2. 存在可选参数。

注册映射的另一种方法是通过 register_custom_acc_mapper_fn()。这种方法旨在减少冗余的操作注册,因为它允许您使用一个函数通过一些组合映射到一个或多个现有的 acc 操作。在函数中,您可以做任何您想做的事情。让我们用一个例子来解释它的工作原理。

@register_acc_op
def foo(*, input, other, alpha=1.0):
    return input + alpha * other

@register_custom_acc_mapper_fn(
    op_and_target=("call_function", torch.add),
    arg_replacement_tuples=[
        ("input", "input"),
        ("other", "other"),
        ("alpha", "alpha", this_arg_is_optional),
    ],
)
def custom_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
    """
    `node` is original node, which is a call_function node with target
    being torch.add.
    """
    alpha = 1
    if "alpha" in node.kwargs:
        alpha = node.kwargs["alpha"]
    foo_kwargs = {"input": node["input"], "other": node["other"], "alpha": alpha}
    with node.graph.inserting_before(node):
        foo_node = node.graph.call_function(foo, kwargs=foo_kwargs)
        foo_node.meta = node.meta.copy()
        return foo_node

在自定义映射器函数中,我们构建一个 acc 操作节点并返回它。我们在这里返回的节点将接管原始节点的所有子节点 acc_normalizer.py.

最后一步将是为我们添加的新 acc 操作或映射器函数添加单元测试。添加单元测试的位置在这里 test_acc_tracer.py.

  • 步骤 2. 添加一个新的转换器

为 acc 操作开发的所有转换器都位于 acc_op_converter.py 中。它可以为您提供一个有关如何添加转换器的良好示例。

本质上,转换器是将 acc 操作映射到 TensorRT 层的映射机制。如果我们能够找到所有需要的 TensorRT 层,我们可以开始使用 TensorRT API 为节点添加一个转换器。

@tensorrt_converter(acc_ops.sigmoid)
def acc_ops_sigmoid(network, target, args, kwargs, name):
    """
    network: TensorRT network. We'll be adding layers to it.

    The rest arguments are attributes of fx node.
    """
    input_val = kwargs['input']

    if not isinstance(input_val, trt.tensorrt.ITensor):
        raise RuntimeError(f'Sigmoid received input {input_val} that is not part '
                        'of the TensorRT region!')

    layer = network.add_activation(input=input_val, type=trt.ActivationType.SIGMOID)
    layer.name = name
    return layer.get_output(0)

我们需要使用 tensorrt_converter 装饰器来注册转换器。装饰器的参数是我们需要转换的 fx 节点的目标。在转换器中,我们可以在 kwargs 中找到 fx 节点的输入。例如,原始节点是 acc_ops.sigmoid,它在 acc_ops.py 中只有一个参数“input”。我们获取输入并检查它是否是一个 TensorRT 张量。之后,我们将一个 sigmoid 层添加到 TensorRT 网络并返回层的输出。我们返回的输出将由 fx.Interpreter 传递给 acc_ops.sigmoid 的子节点。

如果我们无法在 TensorRT 中找到与节点执行相同操作的对应层怎么办?

在这种情况下,我们需要做更多工作。TensorRT 提供插件,充当自定义层。我们还没有实现此功能。我们将在启用后更新

最后一步是为我们添加的新转换器添加单元测试。用户可以在此 文件夹 中添加相应的单元测试。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源