用于模型检查的特征提取¶
torchvision.models.feature_extraction
包包含了特征提取工具,使我们能够利用模型来访问输入的中间变换。这对于计算机视觉中的多种应用非常有用。以下是一些示例:
可视化特征图。
提取特征以计算图像描述符,用于人脸识别、复制检测或图像检索等任务。
将选定的特征传递给下游子网络,以针对特定任务进行端到端训练。例如,将特征层次结构传递给带有目标检测头的特征金字塔网络 (Feature Pyramid Network)。
Torchvision 为此提供了 create_feature_extractor()
。它大致遵循以下步骤工作:
符号化跟踪模型,以获取其如何一步一步转换输入的图形表示。
将用户选择的图节点设置为输出。
移除所有冗余节点(输出节点下游的任何内容)。
从生成的图生成 Python 代码,并将其与图本身一起打包成一个 PyTorch 模块。
的 torch.fx 文档提供了对上述过程和符号化跟踪内部工作原理更通用、更详细的解释。
关于节点名称
为了指定哪些节点应作为提取特征的输出节点,应熟悉此处使用的节点命名约定(与 torch.fx
中使用的略有不同)。节点名称指定为以 .
分隔的路径,沿着模块层次结构从顶层模块一直到叶操作或叶模块。例如,ResNet-50 中的 "layer4.2.relu"
表示 ResNet
模块的第 4 层第 2 个块中 ReLU 的输出。以下是一些需要注意的细节:
为
create_feature_extractor()
指定节点名称时,您可以提供节点名称的缩写版本作为快捷方式。要了解其工作原理,可以尝试创建一个 ResNet-50 模型,并使用train_nodes, _ = get_graph_node_names(model) print(train_nodes)
打印节点名称,观察与layer4
相关的最后一个节点是"layer4.2.relu_2"
。您可以指定"layer4.2.relu_2"
作为返回节点,或者仅指定"layer4"
,因为按照约定,这指的是layer4
中(按执行顺序)的最后一个节点。如果某个模块或操作重复多次,节点名称会获得额外的
_{int}
后缀以消除歧义。例如,如果加法 (+
) 操作在同一个forward
方法中使用了三次。那么将会有"path.to.module.add"
、"path.to.module.add_1"
、"path.to.module.add_2"
。计数器在直接父级的范围内维护。因此,在 ResNet-50 中有"layer4.1.add"
和"layer4.2.add"
。由于加法操作位于不同的块中,因此无需后缀来消除歧义。
示例
以下是关于如何为 MaskRCNN 提取特征的示例
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
train_nodes, eval_nodes = get_graph_node_names(resnet50())
# The lists returned, are the names of all the graph nodes (in order of
# execution) for the input model traced in train mode and in eval mode
# respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# for this example. But if the model contains control flow that's dependent
# on the training mode, they may be different.
# To specify the nodes you want to extract, you could select the final node
# that appears in each of the main layers:
return_nodes = {
# node_name: user-specified key for output dict
'layer1.2.relu_2': 'layer1',
'layer2.3.relu_2': 'layer2',
'layer3.5.relu_2': 'layer3',
'layer4.2.relu_2': 'layer4',
}
# But `create_feature_extractor` can also accept truncated node specifications
# like "layer1", as it will just pick the last node that's a descendent of
# of the specification. (Tip: be careful with this, especially when a layer
# has multiple outputs. It's not always guaranteed that the last operation
# performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.)
return_nodes = {
'layer1': 'layer1',
'layer2': 'layer2',
'layer3': 'layer3',
'layer4': 'layer4',
}
# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
# 'layer1': output of layer 1,
# 'layer2': output of layer 2,
# 'layer3': output of layer 3,
# 'layer4': output of layer 4,
# }
create_feature_extractor(m, return_nodes=return_nodes)
# Let's put all that together to wrap resnet50 with MaskRCNN
# MaskRCNN requires a backbone with an attached FPN
class Resnet50WithFPN(torch.nn.Module):
def __init__(self):
super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone
m = resnet50()
# Extract 4 main layers (note: MaskRCNN needs this particular name
# mapping for return nodes)
self.body = create_feature_extractor(
m, return_nodes={f'layer{k}': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224)
with torch.no_grad():
out = self.body(inp)
in_channels_list = [o.shape[1] for o in out.values()]
# Build FPN
self.out_channels = 256
self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x):
x = self.body(x)
x = self.fpn(x)
return x
# Now we can build our model!
model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
API 参考¶
|
创建一个新的图模块,该模块将给定模型中的中间节点以字典形式返回,其中用户指定的键作为字符串,请求的输出作为值。 |
|
返回按执行顺序排列的节点名称的开发工具。 |