• 文档 >
  • 在 C++ 中使用 Torch-TensorRT
快捷方式

在 C++ 中使用 Torch-TensorRT

如果您还没有这样做,请按照安装说明获取库的 tar 包。

在 C++ 中使用 Torch-TensorRT

Torch-TensorRT C++ API 接受 TorchScript 模块(由 torch.jit.scripttorch.jit.trace 生成)作为输入,并返回一个 Torchscript 模块(使用 TensorRT 优化)。但是,C++ API 将不支持 Dynamo 编译工作流,但支持对 FX 和 Dynamo 工作流执行由 torch.jit.trace 生成的已编译 FX GraphModules。

请参阅在 Python 中创建 TorchScript 模块一节来生成 torchscript 图。

[Torch-TensorRT 快速入门] 使用 torchtrtc 编译 TorchScript 模块

开始使用 Torch-TensorRT 并检查您的模型是否无需额外工作即可受支持的一种简单方法是通过 torchtrtc 运行它。 torchtrtc 支持编译器几乎所有的命令行功能,包括训练后量化(如果提供预先创建的校准缓存)。例如,我们可以通过设置首选的操作精度和输入大小来编译我们的 lenet 模型。这个新的 TorchScript 文件可以加载到 Python 中(注意:在加载这些已编译模块之前,您需要 import torch_tensorrt,因为编译器扩展了 PyTorch 的反序列化器和运行时来执行已编译的模块)。

 torchtrtc -p f16 lenet_scripted.ts trt_lenet_scripted.ts "(1,1,32,32)" python3
Python 3.6.9 (default, Apr 18 2020, 01:56:04)
[GCC 8.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_tensorrt
>>> ts_model = torch.jit.load(“trt_lenet_scripted.ts”)
>>> ts_model(torch.randn((1,1,32,32)).to(“cuda”).half())

您可以在此处了解更多关于 torchtrtc 的用法:torchtrtc

在 C++ 中使用 TorchScript

如果我们正在开发一个使用 C++ 部署的应用程序,我们可以使用 torch.jit.save 保存我们的跟踪(traced)或脚本化(scripted)模块,它会将 TorchScript 代码、权重和其他信息序列化到一个包中。这也是我们对 Python 的依赖结束的地方。

torch_script_module.save("lenet.jit.pt")

接下来,我们可以在 C++ 中加载我们的 TorchScript 模块。

#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
    torch::jit::Module module;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("<PATH TO SAVED TS MOD>");
    }
    catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";

如果您愿意,您可以使用 PyTorch / LibTorch 在 C++ 中进行完整的训练和推理,您甚至可以在 C++ 中定义您的模块并访问支持 PyTorch 的强大张量库。(更多信息请参阅:https://pytorch.ac.cn/cppdocs/)。例如,我们可以使用我们的 LeNet 模块进行如下推理:

mod.eval();
torch::Tensor in = torch::randn({1, 1, 32, 32});
auto out = mod.forward(in);

并在 GPU 上运行

mod.eval();
mod.to(torch::kCUDA);
torch::Tensor in = torch::randn({1, 1, 32, 32}, torch::kCUDA);
auto out = mod.forward(in);

正如您所见,这与 Python API 非常相似。当您调用 forward 方法时,您会调用 PyTorch JIT 编译器,它将优化并运行您的 TorchScript 代码。

在 C++ 中使用 Torch-TensorRT 进行编译

我们现在也可以使用 Torch-TensorRT 编译和优化我们的模块,但不是以 JIT(即时)方式,而是必须采用 AOT(提前)方式进行,即在我们开始实际推理工作之前进行,因为优化模块需要一些时间,每次运行模块甚至首次运行时都这样做没有意义。

加载模块后,我们可以将其输入到 Torch-TensorRT 编译器。执行此操作时,我们必须提供有关预期输入大小的一些信息,并配置任何附加设置。

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"
...

    mod.to(at::kCUDA);
    mod.eval();
    std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
    torch_tensorrt::ts::CompileSpec cfg(inputs);
    auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
    auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA});
    auto out = trt_mod.forward({in});

就这样!现在图主要不是由 JIT 编译器运行,而是使用 TensorRT(尽管我们仍然使用 JIT 运行时执行图)。

我们还可以设置操作精度等选项,以便在 FP16 中运行。

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"
...

    mod.to(at::kCUDA);
    mod.eval();

    auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA}).to(torch::kHALF);
    std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
    torch_tensorrt::ts::CompileSpec cfg(inputs);
    cfg.enable_precisions.insert(torch::kHALF);
    auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
    auto out = trt_mod.forward({in});

现在我们正在以 FP16 精度运行模块。您可以将模块保存起来以便以后加载。

trt_mod.save("<PATH TO SAVED TRT/TS MOD>")

Torch-TensorRT 编译的 TorchScript 模块与普通 TorchScript 模块的加载方式相同。确保您的部署应用程序链接到 libtorchtrt.so

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

int main(int argc, const char* argv[]) {
    torch::jit::Module module;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("<PATH TO SAVED TRT/TS MOD>");
    }
    catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    torch::Tensor in = torch::randn({1, 1, 32, 32}, torch::kCUDA);
    auto out = mod.forward(in);

    std::cout << "ok\n";
}

如果您想保存 Torch-TensorRT 生成的引擎以便在 TensorRT 应用程序中使用,您可以使用 ConvertGraphToTRTEngine API。

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"
...

    mod.to(at::kCUDA);
    mod.eval();

    auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA}).to(torch::kHALF);

    std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
    torch_tensorrt::ts::CompileSpec cfg(inputs);
    cfg.enabled_precisions.insert(torch::kHALF);
    auto trt_mod = torch_tensorrt::ts::convert_method_to_trt_engine(mod, "forward", cfg);
    std::ofstream out("/tmp/engine_converted_from_jit.trt");
    out << engine;
    out.close();

幕后

当一个模块提供给 Torch-TensorRT 时,编译器首先会将您在上方看到的图映射到如下所示的图:

graph(%input.2 : Tensor):
    %2 : Float(84, 10) = prim::Constant[value=<Tensor>]()
    %3 : Float(120, 84) = prim::Constant[value=<Tensor>]()
    %4 : Float(576, 120) = prim::Constant[value=<Tensor>]()
    %5 : int = prim::Constant[value=-1]() # x.py:25:0
    %6 : int[] = prim::Constant[value=annotate(List[int], [])]()
    %7 : int[] = prim::Constant[value=[2, 2]]()
    %8 : int[] = prim::Constant[value=[0, 0]]()
    %9 : int[] = prim::Constant[value=[1, 1]]()
    %10 : bool = prim::Constant[value=1]() # ~/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0
    %11 : int = prim::Constant[value=1]() # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
    %12 : bool = prim::Constant[value=0]() # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
    %self.classifier.fc3.bias : Float(10) = prim::Constant[value= 0.0464  0.0383  0.0678  0.0932  0.1045 -0.0805 -0.0435 -0.0818  0.0208 -0.0358 [ CUDAFloatType{10} ]]()
    %self.classifier.fc2.bias : Float(84) = prim::Constant[value=<Tensor>]()
    %self.classifier.fc1.bias : Float(120) = prim::Constant[value=<Tensor>]()
    %self.feat.conv2.weight : Float(16, 6, 3, 3) = prim::Constant[value=<Tensor>]()
    %self.feat.conv2.bias : Float(16) = prim::Constant[value=<Tensor>]()
    %self.feat.conv1.weight : Float(6, 1, 3, 3) = prim::Constant[value=<Tensor>]()
    %self.feat.conv1.bias : Float(6) = prim::Constant[value= 0.0530 -0.1691  0.2802  0.1502  0.1056 -0.1549 [ CUDAFloatType{6} ]]()
    %input0.4 : Tensor = aten::_convolution(%input.2, %self.feat.conv1.weight, %self.feat.conv1.bias, %9, %8, %9, %12, %8, %11, %12, %12, %10) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0
    %input0.5 : Tensor = aten::relu(%input0.4) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
    %input1.2 : Tensor = aten::max_pool2d(%input0.5, %7, %6, %8, %9, %12) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
    %input0.6 : Tensor = aten::_convolution(%input1.2, %self.feat.conv2.weight, %self.feat.conv2.bias, %9, %8, %9, %12, %8, %11, %12, %12, %10) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py:346:0
    %input2.1 : Tensor = aten::relu(%input0.6) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
    %x.1 : Tensor = aten::max_pool2d(%input2.1, %7, %6, %8, %9, %12) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:539:0
    %input.1 : Tensor = aten::flatten(%x.1, %11, %5) # x.py:25:0
    %27 : Tensor = aten::matmul(%input.1, %4)
    %28 : Tensor = trt::const(%self.classifier.fc1.bias)
    %29 : Tensor = aten::add_(%28, %27, %11)
    %input0.2 : Tensor = aten::relu(%29) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
    %31 : Tensor = aten::matmul(%input0.2, %3)
    %32 : Tensor = trt::const(%self.classifier.fc2.bias)
    %33 : Tensor = aten::add_(%32, %31, %11)
    %input1.1 : Tensor = aten::relu(%33) # ~/.local/lib/python3.6/site-packages/torch/nn/functional.py:1063:0
    %35 : Tensor = aten::matmul(%input1.1, %2)
    %36 : Tensor = trt::const(%self.classifier.fc3.bias)
    %37 : Tensor = aten::add_(%36, %35, %11)
    return (%37)
(CompileGraph)

现在,图已经从管理各自参数的模块集合转换为一个单一的图,参数被内联到图中,并且所有操作都已布置妥当。Torch-TensorRT 还执行了许多优化和映射,使图更容易转换为 TensorRT。从这里开始,编译器可以通过遵循图中的数据流来组装 TensorRT 引擎。

图构建阶段完成后,Torch-TensorRT 会生成一个序列化的 TensorRT 引擎。从这里开始,根据所使用的 API,该引擎会被返回给用户或进入图构建阶段。在这里,Torch-TensorRT 会创建一个 JIT 模块来执行 TensorRT 引擎,该引擎将由 Torch-TensorRT 运行时实例化和管理。

以下是编译完成后您获得的图:

graph(%self_1 : __torch__.lenet, %input_0 : Tensor):
    %1 : ...trt.Engine = prim::GetAttr[name="lenet"](%self_1)
    %3 : Tensor[] = prim::ListConstruct(%input_0)
    %4 : Tensor[] = trt::execute_engine(%3, %1)
    %5 : Tensor = prim::ListUnpack(%4)
    return (%5)

您可以看到执行引擎的调用,在提取包含引擎的属性并构建输入列表后,它会将张量返回给用户。

处理不受支持的操作符

Torch-TensorRT 是一个新库,而 PyTorch 的操作符库非常庞大,因此有些操作符编译器本身并不支持。您可以采用上面所示的组合技术,将完全受 Torch-TensorRT 支持的模块与不受支持的模块分开,然后在部署应用程序中将这些模块拼接在一起;或者,您可以为缺失的操作符注册转换器。

您可以使用 torch_tensorrt::CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name) API 来检查支持情况,而无需执行完整的编译流程,以查看哪些操作符不受支持。 torchtrtc 在开始编译之前会自动使用此方法检查模块,并打印出不受支持的操作符列表。

注册自定义转换器

操作符通过模块化转换器映射到 TensorRT,转换器是一个函数,它接收 JIT 图中的一个节点,并在 TensorRT 中生成等效的层或子图。Torch-TensorRT 自带一个转换器库,存储在一个注册表中,将根据被解析的节点执行相应的转换器。例如,aten::relu(%input0.4) 指令会触发 relu 转换器在其上运行,从而在 TensorRT 图中生成一个激活层。但由于这个库并不详尽,您可能需要编写自己的转换器以使 Torch-TensorRT 支持您的模块。

Torch-TensorRT 发行版中附带了内部核心 API 头文件。因此,您可以访问转换器注册表并添加您需要的操作符的转换器。

例如,如果我们尝试使用不支持展平(flatten)操作符(aten::flatten)的 Torch-TensorRT 构建版本编译图,您可能会看到以下错误:

terminate called after throwing an instance of 'torch_tensorrt::Error'
what():  [enforce fail at core/conversion/conversion.cpp:109] Expected converter to be true but got false
Unable to convert node: %input.1 : Tensor = aten::flatten(%x.1, %11, %5) # x.py:25:0 (conversion.AddLayer)
Schema: aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
Converter for aten::flatten requested, but no such converter was found.
If you need a converter for this operator, you can try implementing one yourself
or request a converter: https://www.github.com/NVIDIA/Torch-TensorRT/issues

我们可以在应用程序中为这个操作符注册一个转换器。所有构建转换器所需的工具都可以通过包含 torch_tensorrt/core/conversion/converters/converters.h 导入。我们首先创建一个自注册类 torch_tensorrt::core::conversion::converters::RegisterNodeConversionPatterns() 的实例,它会在全局转换器注册表中注册转换器,将一个函数 schema(例如 aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor))与一个 lambda 函数关联起来。这个 lambda 函数将接收转换状态、要转换的节点/操作以及节点的所有输入,其副作用是在 TensorRT 网络中生成一个新的层。参数以 TensorRT ITensors 和 Torch IValues 可检查联合的向量形式传递,顺序与 schema 中列出的参数顺序一致。

下面是一个 aten::flatten 转换器的实现示例,我们可以在应用程序中使用它。在转换器实现中,您可以完全访问 Torch 和 TensorRT 库。因此,例如,我们可以通过在 PyTorch 中直接运行操作来快速获取输出大小,而不是像我们下面为这个展平转换器那样自己实现完整的计算。

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"
#include "torch_tensorrt/core/conversion/converters/converters.h"

static auto flatten_converter = torch_tensorrt::core::conversion::converters::RegisterNodeConversionPatterns()
    .pattern({
        "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
        [](torch_tensorrt::core::conversion::ConversionCtx* ctx,
           const torch::jit::Node* n,
           torch_tensorrt::core::conversion::converters::args& args) -> bool {
            auto in = args[0].ITensor();
            auto start_dim = args[1].unwrapToInt();
            auto end_dim = args[2].unwrapToInt();
            auto in_shape = torch_tensorrt::core::util::toVec(in->getDimensions());
            auto out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes();

            auto shuffle = ctx->net->addShuffle(*in);
            shuffle->setReshapeDimensions(torch_tensorrt::core::util::toDims(out_shape));
            shuffle->setName(torch_tensorrt::core::util::node_info(n).c_str());

            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
            return true;
        }
    });

int main() {
    ...

要在 Python 中使用此转换器,建议使用 PyTorch 的C++ / CUDA 扩展模板将您的转换器库包装成一个 .so 文件,您可以在 Python 应用程序中使用 ctypes.CDLL() 加载它。

您可以在贡献者文档(writing_converters)中找到关于编写转换器所有细节的更多信息。如果您有很多转换器实现,请考虑将它们贡献回上游,我们欢迎拉取请求(PR),这将极大地造福社区。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源