get_graph_node_names¶
- torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) Tuple[List[str], List[str]] [源代码]¶
用于按执行顺序返回节点名称的开发实用工具。参阅关于节点名称的说明,位于
create_feature_extractor()
下方。有助于查看哪些节点名称可用于特征提取。节点名称不能直接从模型代码中轻松读取的原因有两个并非所有子模块都会被追踪。`torch.nn` 中的所有模块都属于此类。
表示重复应用同一操作或叶子模块的节点会带有一个 `_{counter}` 后缀。
模型会被追踪两次:一次在训练模式下,一次在评估模式下。两次追踪得到的节点名称列表都会被返回。
有关此处使用的节点命名约定的更多详细信息,请参阅 相关副标题 在 文档中。
- 参数:
model (nn.Module) – 我们想要打印节点名称的模型
tracer_kwargs (dict, 可选) – `NodePathTracer` 的关键字参数字典(它们最终会传递给 torch.fx.Tracer)。默认情况下,它将被设置为包装所有 torchvision 操作并使其成为叶子节点:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果用户提供了 `tracer_kwargs`,上述默认参数将附加到用户提供的字典中。
suppress_diff_warning (bool, 可选) – 当图的训练版本和评估版本存在差异时,是否抑制警告。默认为 False。
concrete_args (Optional[Dict[str, any]]) – 不应被视为 Proxies 的具体参数。根据 Pytorch 文档,此参数的 API 可能无法保证。
- 返回:
一个列表,包含在训练模式下追踪模型得到的节点名称;另一个列表,包含在评估模式下追踪模型得到的节点名称。
- 返回类型:
示例
>>> model = torchvision.models.resnet18() >>> train_nodes, eval_nodes = get_graph_node_names(model)