• 文档 >
  • 在 Python 中使用 Torch-TensorRT
快捷方式

在 Python 中使用 Torch-TensorRT

与仅支持 TorchScript 编译的 CLI 和 C++ API 相比,Torch-TensorRT Python API 支持多种独特的用例。

Torch-TensorRT Python API 可以接受一个 torch.nn.Moduletorch.jit.ScriptModuletorch.fx.GraphModule 作为输入。根据提供的类型,将选择两个前端(TorchScript 或 FX)之一来编译模块。只要模块类型受支持,用户可以使用 ir 标志显式设置他们想要使用的前端 compile。如果给定一个 torch.nn.Module 并且 ir 标志设置为 defaulttorchscript,则该模块将通过 torch.jit.script 运行,以将输入模块转换为 TorchScript 模块。

要使用 Torch-TensorRT 编译您的输入 torch.nn.Module,您只需提供模块和输入到 Torch-TensorRT,您将获得一个优化的 TorchScript 模块来运行或添加到另一个 PyTorch 模块中。Inputs 是一个 torch_tensorrt.Input 类列表,它定义了输入张量的形状、数据类型和内存格式。或者,如果您的输入是更复杂的数据类型,例如张量元组或列表,您可以使用 input_signature 参数指定基于集合的输入,例如 (List[Tensor], Tuple[Tensor, Tensor])。有关示例,请参见下面的第二个示例。您还可以指定引擎或目标设备的操作精度等设置。编译后,您可以像保存任何其他模块一样保存该模块,以便在部署应用程序中加载。为了加载 TensorRT/TorchScript 模块,请确保首先导入 torch_tensorrt

import torch_tensorrt

...

model = MyModel().eval()  # torch module needs to be in eval (not training) mode

inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 1, 16, 16],
        opt_shape=[1, 1, 32, 32],
        max_shape=[1, 1, 64, 64],
        dtype=torch.half,
    )
]
enabled_precisions = {torch.float, torch.half}  # Run with fp16

trt_ts_module = torch_tensorrt.compile(
    model, inputs=inputs, enabled_precisions=enabled_precisions
)

input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
# Sample using collection-based inputs via the input_signature argument
import torch_tensorrt

...

model = MyModel().eval()

# input_signature expects a tuple of individual input arguments to the module
# The module below, for example, would have a docstring of the form:
# def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
input_signature = (
    [torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
    (torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
)
enabled_precisions = {torch.float, torch.half}

trt_ts_module = torch_tensorrt.compile(
    model, input_signature=input_signature, enabled_precisions=enabled_precisions
)

input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
# Deployment application
import torch
import torch_tensorrt

trt_ts_module = torch.jit.load("trt_ts_module.ts")
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)

Torch-TensorRT Python API 还提供了 torch_tensorrt.ts.compile,它接受一个 TorchScript 模块作为输入,以及 torch_tensorrt.fx.compile,它接受一个 FX GraphModule 作为输入。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源