torch_tensorrt.fx¶
函数¶
- torch_tensorrt.fx.compile(module: Module, input, min_acc_module_size: int = 10, max_batch_size: int = 2048, max_workspace_size=33554432, explicit_batch_dimension=False, lower_precision=LowerPrecision.FP16, verbose_log=False, timing_cache_prefix='', save_timing_cache=False, cuda_graph_batch_size=- 1, dynamic_batch=True, is_aten=False, use_experimental_fx_rt=False, correctness_atol=0.1, correctness_rtol=0.1) Module [源代码]¶
获取原始模块、输入和下降设置,运行下降工作流将模块转换为下降后的模块,即所谓的 TRTModule。
- 参数
module – 原始模块,用于下降。
input – 模块的输入。
max_batch_size – 最大批量大小(必须 ≥ 1 才能设置,0 表示未设置)
min_acc_module_size – 加速子模块所需的最少节点数
max_workspace_size – 提供给 TensorRT 的最大工作空间大小。
explicit_batch_dimension – 如果设置为 True,在 TensorRT 中使用显式批量维度,否则使用隐式批量维度。
lower_precision – 提供给 TRTModule 的 lower_precision 配置。
verbose_log – 如果设置为 True,启用 TensorRT 的详细日志。
timing_cache_prefix – fx2trt 使用的计时缓存文件的名称。
save_timing_cache – 如果设置为 True,使用当前的计时缓存数据更新计时缓存。
cuda_graph_batch_size – Cuda 图批量大小,默认为 -1。
dynamic_batch – 批量维度 (dim=0) 是否为动态。
use_experimental_fx_rt – 使用下一代 TRTModule,它支持基于 Python 和 TorchScript 的执行(包括在 C++ 中)。
- 返回值
由 TensorRT 下降处理后的 torch.nn.Module。
类¶
- class torch_tensorrt.fx.TRTModule(engine=None, input_names=None, output_names=None, cuda_graph_batch_size=- 1)[源代码]¶
- class torch_tensorrt.fx.InputTensorSpec(shape: Sequence[int], dtype: dtype, device: device = device(type='cpu'), shape_ranges: List[Tuple[Sequence[int], Sequence[int], Sequence[int]]] = [], has_batch_dim: bool = True)[源代码]¶
此类包含输入张量的信息。
shape: 张量的形状。
dtype: 张量的数据类型。
- device: 张量的设备。这仅用于生成给定模型的输入,以便运行形状推理。对于 TensorRT 引擎,输入必须在 cuda 设备上。
(续 device 描述)
- shape_ranges: 如果需要动态形状(形状维度为 -1),则必须提供此字段(默认为空列表)。每个 shape_range 是一个包含三个元组的元组 ((min_input_shape), (optimized_input_shape), (max_input_shape))。每个 shape_range 用于填充一个 TensorRT 优化配置。例如,如果输入形状从 (1, 224) 变化到 (100, 224),并且我们希望对 (25, 224) 进行优化,因为它是最常见的输入形状,那么我们将 shape_ranges 设置为 ((1, 224), (25, 225), (100, 224))。
(续 shape_ranges 描述)
- has_batch_dim: 形状是否包含批量维度。如果引擎需要使用动态形状运行,则必须提供批量维度。
(续 has_batch_dim 描述)
- class torch_tensorrt.fx.TRTInterpreter(module: GraphModule, input_specs: List[InputTensorSpec], explicit_batch_dimension: bool = False, explicit_precision: bool =False, logger_level=None)[源代码]¶