create_feature_extractor¶
- torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) GraphModule [源]¶
创建一个新的图模块,该模块将给定模型中的中间节点作为字典返回,其中用户指定的键为字符串,请求的输出作为值。这通过使用 FX 重写模型的计算图来实现,从而将所需的节点作为输出返回。所有未使用的节点及其相应的参数都将被移除。
所需的输出节点必须指定为通过
.
分隔的路径,该路径从顶层模块向下遍历模块层次结构,直到叶操作或叶模块。有关此处使用的节点命名约定的更多详细信息,请参阅文档中的相关小标题。并非所有模型都可以通过 FX 进行跟踪,尽管经过一些调整后它们可以协同工作。以下是一些(非详尽的)技巧列表
如果您不需要跟踪某个特定、有问题的子模块,可以通过将
leaf_modules
列表作为tracer_kwargs
之一传递来将其转换为“叶模块”(参见下面的示例)。它不会被跟踪,而是生成的图将保留对该模块 forward 方法的引用。同样,您可以通过将
autowrap_functions
列表作为tracer_kwargs
之一传递来将函数转换为叶函数(参见下面的示例)。一些内置的 Python 函数可能会有问题。例如,
int
在跟踪期间会引发错误。您可以将它们包装在您自己的函数中,然后将其作为tracer_kwargs
之一传递给autowrap_functions
。
有关 FX 的更多信息,请参阅torch.fx 文档。
- 参数:
model (nn.Module) – 我们将从中提取特征的模型
return_nodes (list or dict, optional) – 一个
List
或Dict
,包含将返回其激活值的节点的名称(或部分名称 - 见上方注释)。如果它是一个Dict
,键是节点名称,值是图模块返回字典的用户指定键。如果它是一个List
,它被视为一个Dict
,将节点规范字符串直接映射到输出名称。如果指定了train_return_nodes
和eval_return_nodes
,则不应指定此参数。train_return_nodes (list or dict, optional) – 类似于
return_nodes
。如果在训练模式下的返回节点与评估模式下的不同,则可以使用此参数。如果指定了此参数,则必须同时指定eval_return_nodes
,并且不应指定return_nodes
。eval_return_nodes (list or dict, optional) – 类似于
return_nodes
。如果在训练模式下的返回节点与评估模式下的不同,则可以使用此参数。如果指定了此参数,则必须同时指定train_return_nodes
,并且不应指定 return_nodes。tracer_kwargs (dict, optional) –
NodePathTracer
的关键字参数字典(NodePathTracer
会将其传递给其父类 torch.fx.Tracer)。默认情况下,它将设置为包装所有 torchvision 算子并使其成为叶节点:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果用户提供了 tracer_kwargs,上述默认参数将被附加到用户提供的字典中。suppress_diff_warning (bool, optional) – 当训练和评估版本的图之间存在差异时,是否抑制警告。默认为 False。
concrete_args (Optional[Dict[str, any]]) – 不应被视为 Proxies 的具体参数。根据PyTorch 文档,此参数的 API 可能无法保证。
示例
>>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> model = create_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] >>> # Specifying leaf modules and leaf functions >>> def leaf_function(x): >>> # This would raise a TypeError if traced through >>> return int(x) >>> >>> class LeafModule(torch.nn.Module): >>> def forward(self, x): >>> # This would raise a TypeError if traced through >>> int(x.shape[0]) >>> return torch.nn.functional.relu(x + 4) >>> >>> class MyModule(torch.nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.conv = torch.nn.Conv2d(3, 1, 3) >>> self.leaf_module = LeafModule() >>> >>> def forward(self, x): >>> leaf_function(x.shape[0]) >>> x = self.conv(x) >>> return self.leaf_module(x) >>> >>> model = create_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]})