快捷方式

降低阶段

降低阶段由一系列遍组成,这些遍是将图从高级表示映射到低级表示的操作。每个遍执行特定的任务,例如内联方法调用。其目的是显著减少实际映射到 TensorRT 时转换阶段需要处理的内容。我们旨在实现更接近 1 对 1 的运算符转换,而不是寻找适用的子图,从而限制转换器的数量并缩小每个转换器的范围。

通过将日志级别设置为 Level::kGraph,您可以查看每个遍的效果

使用的遍

消除公共子表达式

移除图中的公共子表达式

消除死代码

死代码消除将检查节点是否具有副作用,如果具有副作用则不会删除它。

消除异常或通过模式

脚本化模块中的常见模式是维度守卫,如果输入维度与预期不符,它将抛出异常。

%1013 : bool = aten::ne(%1012, %24) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:11
    = prim::If(%1013) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:8
    block0():
        = prim::RaiseException(%23) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:249:12
    -> ()
    block1():
    -> ()

由于我们在编译时解析所有这些内容,并且 TensorRT 图中没有异常,我们只需将其移除。

消除冗余守卫

消除输出完全由其输入确定的运算符的冗余守卫,即如果此类运算符的输入受守卫保护,则允许我们移除运算符输出上的守卫

冻结模块

冻结属性并内联常量和模块。在图中传播常量。

融合 AddMM 分支

脚本化模块中的常见模式是不同维度的张量使用不同的结构来实现线性层。我们将这些不同的变体融合成一个单一的,该单一结构将被 Unpack AddMM 遍捕获。

%ret : Tensor = prim::If(%622)
block0():
  %ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
  -> (%ret.1)
block1():
  %output.1 : Tensor = aten::matmul(%x9.1, %3677)
  %output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
  -> (%output0.1)

我们将这组块融合成一个如下所示的图

%ret : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)

融合 Linear

匹配 aten::linear 模式并将其融合成一个单一的 aten::linear。此遍将 JIT 生成的 addmm 或 matmul + add 重新融合成 linear

融合 Flatten Linear

当输入层高于 1D 时,TensorRT 会将其隐式展平为全连接层。因此,当存在 aten::flatten -> aten::linear 模式时,我们移除 aten::flatten

降低图

给定一个方法图,其第一个参数是 %self,将其降低为一个图,其中所有属性访问都被替换为图的显式输入(而不是对 %self 执行 prim::GetAttr 的结果)。返回一个元组 (graph, parameters),其中图的最后 module.parameters.size() 个输入是此方法中使用的可训练参数。其余输入是函数的真实输入。

降低元组

  • 降低简单元组:

移除 TupleConstruct 和 TupleUnpack 匹配的元组,但保留 if 语句、循环以及输入/输出中的元组

  • 降低所有元组:

移除 _所有_ 元组,如果某些元组无法移除则抛出错误。ONNX 使用此遍以确保转换前没有元组,但对输入包含元组的图无效。

模块回退

模块回退包含两个必须成对运行的降低遍。第一个遍在冻结前运行,在图中标注应在 PyTorch 中运行的模块周围的定界符。第二个遍在冻结后标记这些定界符之间的节点,以指示它们应在 PyTorch 中运行。

  • 标注模块回退

在冻结前在模块调用周围放置定界节点,以指示图中的哪些节点应在 PyTorch 中运行

  • 标记回退节点

查找定界符,然后标记定界符之间的所有节点,以告知分区应在 PyTorch 中运行它们

窥孔优化

此优化遍旨在捕获您可能感兴趣的所有小型、易于捕获的窥孔优化。

目前,它执行以下操作:
  • 消除无操作的 ‘expand’ 节点

  • 将 x.t().t() 简化为 x

移除 Contiguous

移除 contiguous 运算符,因为我们执行 TensorRT 时内存已是连续的。

移除 Dropout

移除 dropout 运算符,因为我们在执行推理。

移除 To

移除执行类型转换的 aten::to 运算符,因为 TensorRT 会自行管理。重要的是,这是最后运行的遍之一,以便其他遍有机会将所需的类型转换运算符移出主命名空间。

解包 AddMM

aten::addmm 解包为 aten::matmulaten::add_(并添加一个 trt::const 运算符以在 TensorRT 图中冻结偏置)。这使我们能够重用 aten::matmulaten::add_ 转换器,而无需专门的转换器。

解包 LogSoftmax

aten::logsoftmax 解包为 aten::softmaxaten::log。这使我们能够重用 aten::softmaxaten::log 转换器,而无需专门的转换器。

循环展开

展开兼容的循环(例如足够短的循环)的操作,以便您只需遍历循环一次。

将 Tile 替换为 Repeat

移除 dropout 运算符,因为我们在执行推理。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源