快捷方式

降低阶段

降低阶段由传递组成,传递是将图从高级表示映射到低级表示的操作。每个传递都执行特定操作,例如内联方法调用。目的是显着减少转换阶段在实际映射到 TensorRT 时需要处理的内容。我们的目标是更接近 1->1 的操作转换,而不是寻找适用的子图,从而限制转换器的数量并缩小每个转换器的范围。

可以通过将日志级别设置为 Level::kGraph 来查看每个传递的效果

使用的传递

消除公共子表达式

删除图中的公共子表达式

消除死代码

死代码消除将检查节点是否具有副作用,如果具有副作用则不会删除它。

消除异常或传递模式

脚本模块中常见的模式是维度保护,如果输入维度不符合预期,它将抛出异常。

%1013 : bool = aten::ne(%1012, %24) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:11
    = prim::If(%1013) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:8
    block0():
        = prim::RaiseException(%23) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:249:12
    -> ()
    block1():
    -> ()

由于我们在编译时解析所有这些内容,并且 TensorRT 图中没有异常,因此我们只需将其删除。

消除冗余保护

消除操作的冗余保护,这些操作的输出完全由其输入决定,即如果此类操作的输入受到保护,我们允许删除对操作输出的保护

冻结模块

冻结属性并内联常量和模块。传播图中的常量。

融合 AddMM 分支

脚本模块中常见的模式是不同维度的张量使用不同的构造来实现线性层。我们将这些不同的变体融合成一个变体,该变体将被 Unpack AddMM 传递捕获。

%ret : Tensor = prim::If(%622)
block0():
  %ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
  -> (%ret.1)
block1():
  %output.1 : Tensor = aten::matmul(%x9.1, %3677)
  %output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
  -> (%output0.1)

我们将这组块融合成一个类似于此的图

%ret : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)

融合线性

匹配 aten::linear 模式并将其融合成单个 aten::linear 此传递将由 JIT 生成的 addmm 或 matmul + add 融合回线性

融合扁平化线性

当输入层高于 1D 时,TensorRT 会隐式地将输入层扁平化为全连接层。因此,当存在 aten::flatten -> aten::linear 模式时,我们将删除 aten::flatten

降低图

给定一个图,该图包含一个方法,该方法的第一个参数是 %self,将其降低为一个图,其中所有属性访问都替换为图的显式输入(而不是在 %self 上执行的 prim::GetAttr 的结果)。返回一个元组 (图,参数),其中最后一个模块。parameters.size() 图的输入是此方法中使用的可训练参数。其余输入是函数的真实输入。

降低元组

  • LowerSimpleTuples:

删除元组,其中 TupleConstruct 和 TupleUnpack 匹配,但保留 if 语句、循环和作为输入/输出的元组

  • LowerAllTuples:

删除_所有_元组,如果某些元组无法删除则引发错误,这由 ONNX 用于确保转换前没有元组,但对输入包含元组的图不起作用。

模块回退

模块回退包含两个降低传递,这两个传递必须作为一对运行。第一个传递在冻结之前运行,以在图中放置分隔符,围绕应该在 PyTorch 中运行的模块。第二个传递在冻结后标记这些分隔符之间的节点,以表示它们应该在 PyTorch 中运行。

  • NotateModuleForFallback

在冻结之前在模块调用周围放置分隔节点,以表示图中的哪些节点应该在 PyTorch 中运行

  • MarkNodesForFallback

查找分隔符,然后标记分隔符之间的所有节点,以告知分区在 PyTorch 中运行它们

窥孔优化

此优化传递的目的是捕获所有您可能感兴趣的简单易捕获的窥孔优化。

目前,它执行以下操作
  • 消除无操作“扩展”节点

  • 将 x.t().t() 简化为 x

删除 Contiguous

删除连续操作符,因为我们正在进行 TensorRT 内存已经连续。

删除 Dropout

删除 dropout 操作符,因为我们正在进行推理。

删除 To

删除执行强制转换的 aten::to 操作符,因为 TensorRT 自行管理它。重要的是,这是运行的最后一个传递之一,以便其他传递有机会将所需的强制转换操作符移出主命名空间。

解包 AddMM

aten::addmm 解包为 aten::matmulaten::add_(以及一个额外的 trt::const 操作,用于冻结 TensorRT 图中的偏差)。这使我们能够重用 aten::matmulaten::add_ 转换器,而不必需要专用转换器。

解包 LogSoftmax

aten::logsoftmax 解包为 aten::softmaxaten::log。这使我们能够重复使用 aten::softmaxaten::log 转换器,而无需专门的转换器。

展开循环

展开兼容循环(例如,足够短的循环)的操作,以便您只需遍历循环一次。

用 Repeat 替换 Tile

删除 dropout 操作符,因为我们正在进行推理。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源