编写 Dynamo ATen Lowering Passes¶
降低通道的基础知识¶
ATen lowering passes 是 Python 函数,它将 ATen 运算符图作为输入,应用一些所需的修改,例如运算符合并/融合、运算符替换、子图重写、自定义运算符插入或对 torch.fx.GraphModule 的其他操作,然后将修改后的图返回给调用者。这些 lowering passes 通常就地修改图形并返回相同的输入对象。
降低通道要求¶
Torch-TRT 中的 ATen lowering pass 函数必须满足两个要求: - 该函数必须将 torch.fx.GraphModule 和 torch Tensors 序列 Sequence[torch.Tensor] 作为输入,并返回降低后的 torch.fx.GraphModule - 该函数必须使图形保持有效且可调用的状态,包括执行任何必要的 linting 和重新编译
有关 FX 中 图形操作 的信息,请参阅此链接。 有关修复具有输入同时也是输出的图形的 lowering pass 示例(TRT 引擎不允许的配置),请参见下文。
示例 Lowering Pass¶
def repair_input_as_output(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
"""Repair scenarios where inputs are also outputs of the graph
TRT does not allow such cases, so we insert a clone (identity) layer
"""
modified_graph = False
# Extract graph placeholder Tensors
placeholders = [
node
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.Tensor)
)
]
for placeholder in placeholders:
# If any placeholder has any users which are direct graph outputs
if len(placeholder.users) >= 1 and any(
user.op == "output" for user in placeholder.users
):
modified_graph = True
# Get direct graph outputs which are direct uses of placeholders
direct_outputs = [user for user in placeholder.users if user.op == "output"]
# Insert clone node for placeholder to ensure
# placeholder is not a direct output
with gm.graph.inserting_after(placeholder):
cloned_placeholder = gm.graph.call_function(
torch.ops.aten.clone.default,
args=(placeholder,),
)
# Replace placeholder as output with cloned version
for output in direct_outputs:
output.replace_input_with(placeholder, cloned_placeholder)
# If the graph was modified, clean up the graph and ensure it is up-to-date
if modified_graph:
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
return gm
注册 Lowering Passes¶
Lowering passes 当前在 py/torch_tensorrt/dynamo/lowering/passes/__init__.py 中注册,使用 torch.fx.passes.pass_manager.PassManager 实用程序以所需的顺序组装 passes 列表。 直接添加到该列表的新 passes 将应用于 Torch-TensorRT torch.compile 后端中的图形。 目前,我们为方便起见提供了一个 ATen lowering pass 注册装饰器,可以直接调用该装饰器,也可以使用可选的 index 关键字参数来控制 lowering pass 将插入到 pass 列表中的位置。
例如,要在默认位置(列表末尾)插入 pass,可以使用以下代码
@_aten_lowering_pass
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
...
或者,要在 passlist 中自定义索引(例如列表的前面)插入 pass,可以使用以下代码
@_aten_lowering_pass(index=0)
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
...
在 torch_tensorrt.dynamo.lowering.passes 中还提供了实用程序,用于显示当前可用的 lowering pass 列表、将这些 passes 应用于任意 torch.fx.GraphModule 以及删除特定索引处的 lowering pass。
# Print all lowering passes in the list
print(dump_lowering_passes())
# Apply lowering passes to a GraphModule
apply_lowering_passes(graph_module, sample_inputs)
# Remove the lowering pass at index 1
_remove_lowering_pass(index=1)
注意: 上述 API 可能会发生变化,因为 lowering pass 系统会不断发展。