• 文档 >
  • 使用 Torch-TensorRT TorchScript 前端直接从 PyTorch
快捷方式

使用 Torch-TensorRT TorchScript 前端直接从 PyTorch

现在您将能够直接从 PyTorch API 访问 TensorRT。使用此功能的过程与 在 Python 中使用 Torch-TensorRT 中描述的编译工作流程非常相似

首先将 torch_tensorrt 加载到您的应用程序中。

import torch
import torch_tensorrt

然后,给定一个 TorchScript 模块,您可以使用 torch._C._jit_to_backend("tensorrt", ...) API 使用 TensorRT 编译它。

import torchvision.models as models

model = models.mobilenet_v2(pretrained=True)
script_model = torch.jit.script(model)

与 Torch-TensorRT 中的 compile API 不同,后者假设您尝试编译模块的 forward 函数或 convert_method_to_trt_engine,后者将指定函数转换为 TensorRT 引擎,后端 API 将采用一个字典,该字典将要编译的函数名称映射到 Compilation Spec 对象,这些对象包装了您将提供给 compile 的相同类型的字典。有关编译规范字典的更多信息,请查看 Torch-TensorRT TensorRTCompileSpec API 的文档。

spec = {
    "forward": torch_tensorrt.ts.TensorRTCompileSpec(
        **{
            "inputs": [torch_tensorrt.Input([1, 3, 300, 300])],
            "enabled_precisions": {torch.float, torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_avg_timing_iters": 1,
        }
    )
}

现在要使用 Torch-TensorRT 进行编译,请将目标模块对象和规范字典提供给 torch._C._jit_to_backend("tensorrt", ...)

trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)

要显式运行,请调用要运行的方法的函数(而不是在标准 PyTorch 中可以直接在模块本身上调用)

input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
print(trt_model.forward(input))

© 版权所有 2024,NVIDIA Corporation。

使用 Sphinx 构建,主题由 theme 提供,由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源