函数 torch_tensorrt::torchscript::embed_engine_in_new_module¶
函数文档¶
-
TORCHTRT_API torch::jit::Module torch_tensorrt::torchscript::embed_engine_in_new_module(const std::string &engine, Device device, const std::vector<std::string> &input_binding_names = std::vector<std::string>(), const std::vector<std::string> &output_binding_names = std::vector<std::string>())¶
获取先前创建的 TensorRT 引擎并将其嵌入到 TorchScript 模块中。
获取预构建的序列化 TensorRT 引擎并将其嵌入到 TorchScript 模块中。将引擎的执行注册为模块的前向方法。前向方法定义为:forward(Tensor[]) -> Tensor[]
如果未指定绑定名称,则 TensorRT 绑定必须具有以下格式的名称
[符号].[输入/输出数组中的索引] 例如:
[x.0, x.1, x.2] -> [y.0]
- 参数
engine – std::string - 预构建的序列化 TensorRT 引擎
device – CompileSepc::Device - Device 设备信息
input_binding_names – std::vector<std::string> - TensorRT 绑定的名称,按照原始 PyTorch 函数传入的顺序排列(默认为假设以下约定)
output_binding_names – std::vector<std::string> - TensorRT 绑定的名称,按照原始 PyTorch 函数返回的顺序排列(默认为假设以下约定)
- 返回值
:一个以 TensorRT 引擎为目标的新模块