快捷方式

get_graph_node_names

torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) Tuple[List[str], List[str]][source]

开发实用程序,用于按执行顺序返回节点名称。请参阅create_feature_extractor() 下的节点名称说明。对于查看哪些节点名称可用于特征提取非常有用。无法直接从模型代码中轻松读取节点名称有两个原因

  1. 并非所有子模块都被跟踪。来自torch.nn 的模块都属于此类别。

  2. 表示对同一操作或叶模块重复应用的节点会得到_{counter} 后缀。

该模型被跟踪两次:一次在训练模式下,一次在评估模式下。返回两组节点名称。

有关此处使用的节点命名约定的更多详细信息,请参阅文档 中的相关子标题

参数:
  • model (nn.Module) – 我们要打印节点名称的模型

  • tracer_kwargs (dict, optional) – 用于NodePathTracer 的关键字参数字典(它们最终将传递给torch.fx.Tracer)。默认情况下,它将设置为包装并使所有 torchvision 操作成为叶节点:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果用户提供 tracer_kwargs,则上述默认参数将被附加到用户提供的字典中。

  • suppress_diff_warning (bool, optional) – 是否抑制当图的训练版本和评估版本之间存在差异时的警告。默认为 False。

  • concrete_args (Optional[Dict[str, any]]) – 不应被视为代理的具体参数。根据Pytorch 文档,该参数的 API 可能无法保证。

返回值:

在训练模式下跟踪模型的节点名称列表,以及在评估模式下跟踪模型的另一个节点名称列表。

返回类型:

tuple(list, list)

示例

>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源