快捷方式

create_feature_extractor

torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) GraphModule[source]

创建一个新的图模块,该模块从给定模型返回中间节点,并以用户指定的字符串作为键,请求的输出作为值存储在字典中。这通过重写模型的计算图来实现,通过 FX 返回所需的节点作为输出。所有未使用的节点以及其对应的参数都将被移除。

所需的输出节点必须指定为 . 分隔的路径,该路径从顶级模块向下遍历模块层次结构,直至叶子操作或叶子模块。有关此处使用的节点命名约定的更多详细信息,请参阅 相关子标题,该子标题位于 文档 中。

并非所有模型都可以通过 FX 追踪,但通过一些调整,它们可以协同工作。以下是一些(非详尽的)技巧列表

  • 如果您不需要追踪特定的、有问题的子模块,请通过传递 leaf_modules 列表作为 tracer_kwargs 之一(请参见下面的示例),将其转换为“叶子模块”。它不会被追踪,而是生成的图将保留对该模块 forward 方法的引用。

  • 同样,您可以通过传递 autowrap_functions 列表作为 tracer_kwargs 之一(请参见下面的示例),将函数转换为叶子函数。

  • 一些内置的 Python 函数可能会有问题。例如,int 将在追踪期间引发错误。您可以将它们包装在您自己的函数中,然后将该函数作为 tracer_kwargs 之一传递到 autowrap_functions 中。

有关 FX 的更多信息,请参阅 torch.fx 文档

参数:
  • model (nn.Module) – 我们将在其上提取特征的模型

  • return_nodes (listdict, 可选) – 包含节点名称(或部分名称 - 请参阅上面的注释)的 ListDict,将返回这些节点的激活值。如果它是 Dict,则键是节点名称,值是图模块返回的字典的用户指定键。如果它是 List,则将其视为将节点规范字符串直接映射到输出名称的 Dict。如果指定了 train_return_nodeseval_return_nodes,则不应指定此参数。

  • train_return_nodes (listdict, 可选) – 类似于 return_nodes。如果训练模式的返回节点与评估模式的返回节点不同,则可以使用此参数。如果指定了此参数,则还必须指定 eval_return_nodes,并且不应指定 return_nodes

  • eval_return_nodes (listdict, 可选) – 类似于 return_nodes。如果训练模式的返回节点与评估模式的返回节点不同,则可以使用此参数。如果指定了此参数,则还必须指定 train_return_nodes,并且不应指定 return_nodes

  • tracer_kwargs (dict, 可选) – NodePathTracer 的关键字参数字典(它将这些参数传递给其父类 torch.fx.Tracer)。默认情况下,它将被设置为包装所有 torchvision 操作并使其成为叶子节点:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果用户提供了 tracer_kwargs,则上述默认参数将附加到用户提供的字典中。

  • suppress_diff_warning (bool, 可选) – 是否在训练和评估版本的图之间存在差异时抑制警告。默认为 False。

  • concrete_args (Optional[Dict[str, any]]) – 不应视为 Proxies 的具体参数。根据 Pytorch 文档,此参数的 API 可能无法保证。

示例

>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> model = create_feature_extractor(
>>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>>     [('feat1', torch.Size([1, 64, 56, 56])),
>>>      ('feat2', torch.Size([1, 256, 14, 14]))]

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
>>>     # This would raise a TypeError if traced through
>>>     return int(x)
>>>
>>> class LeafModule(torch.nn.Module):
>>>     def forward(self, x):
>>>         # This would raise a TypeError if traced through
>>>         int(x.shape[0])
>>>         return torch.nn.functional.relu(x + 4)
>>>
>>> class MyModule(torch.nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.conv = torch.nn.Conv2d(3, 1, 3)
>>>         self.leaf_module = LeafModule()
>>>
>>>     def forward(self, x):
>>>         leaf_function(x.shape[0])
>>>         x = self.conv(x)
>>>         return self.leaf_module(x)
>>>
>>> model = create_feature_extractor(
>>>     MyModule(), return_nodes=['leaf_module'],
>>>     tracer_kwargs={'leaf_modules': [LeafModule],
>>>                    'autowrap_functions': [leaf_function]})

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源