引言
基于 FX 的特征提取是 TorchVision 的一项新工具,它允许我们在 PyTorch 模块的前向传播过程中获取输入的中间转换结果。该工具通过符号追踪(Symbolically tracing)前向传播方法,生成一个图(Graph),其中每个节点代表一个单一的操作。节点的命名方式符合人类阅读习惯,因此可以轻松指定想要获取的节点。
这听起来有点复杂吗?别担心,这篇文章适合每一位读者。无论你是初学者还是资深的深度视觉开发者,都有必要了解 FX 特征提取。如果你想了解更多关于特征提取的背景知识,请继续往下读。如果你已经很熟悉这些概念,只想知道如何在 PyTorch 中实现,可以直接跳转到“PyTorch 中现有的方法:优缺点”。如果你已经了解在 PyTorch 中进行特征提取的挑战,请随意跳到“FX 来拯救”。
特征提取回顾
我们都习惯了深度神经网络 (DNN) 接收输入并产生输出的模式,而未必会去思考中间发生了什么。让我们以 ResNet-50 分类模型为例:

图 1:ResNet-50 将一张鸟的图片转化为抽象概念“鸟”。来源:来自 ImageNet 的鸟类图片。
然而,我们知道 ResNet-50 架构中有许多顺序排列的“层”,它们一步步地转换输入。在下方的图 2 中,我们深入底层展示了 ResNet-50 内部的层,同时也展示了输入在通过这些层时产生的中间转换。

图 2:ResNet-50 分多个步骤转换输入图像。从概念上讲,我们可以在每个步骤之后获取图像的中间转换结果。来源:来自 ImageNet 的鸟类图片。
PyTorch 中现有的方法:优缺点
在基于 FX 的特征提取引入之前,PyTorch 中已经有几种进行特征提取的方法。
为了说明这些方法,我们考虑一个简单的卷积神经网络:
- 应用几个“块”(blocks),每个块中包含几个卷积层。
- 在经过几个块之后,使用全局平均池化和展平操作。
- 最后使用一个输出分类层。
import torch
from torch import nn
class ConvBlock(nn.Module):
"""
Applies `num_layers` 3x3 convolutions each followed by ReLU then downsamples
via 2x2 max pool.
"""
def __init__(self, num_layers, in_channels, out_channels):
super().__init__()
self.convs = nn.ModuleList(
[nn.Sequential(
nn.Conv2d(in_channels if i==0 else out_channels, out_channels, 3, padding=1),
nn.ReLU()
)
for i in range(num_layers)]
)
self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
for conv in self.convs:
x = conv(x)
x = self.downsample(x)
return x
class CNN(nn.Module):
"""
Applies several ConvBlocks each doubling the number of channels, and
halving the feature map size, before taking a global average and classifying.
"""
def __init__(self, in_channels, num_blocks, num_classes):
super().__init__()
first_channels = 64
self.blocks = nn.ModuleList(
[ConvBlock(
2 if i==0 else 3,
in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
out_channels=first_channels*(2**i))
for i in range(num_blocks)]
)
self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.global_pool(x)
x = x.flatten(1)
x = self.cls(x)
return x
model = CNN(3, 4, 10)
out = model(torch.zeros(1, 3, 32, 32)) # This will be the final logits over classes
假设我们想要获取全局平均池化之前的最终特征图,我们可以这样做:
修改 forward 方法
def forward(self, x):
for block in self.blocks:
x = block(x)
self.final_feature_map = x
x = self.global_pool(x)
x = x.flatten(1)
x = self.cls(x)
return x
或者直接返回它。
def forward(self, x):
for block in self.blocks:
x = block(x)
final_feature_map = x
x = self.global_pool(x)
x = x.flatten(1)
x = self.cls(x)
return x, final_feature_map
这看起来很简单。但它存在一些缺点,所有缺点都源于同一个潜在问题:修改源代码并非理想方案。
- 考虑到项目的实际情况,获取并修改源代码并不总是那么容易。
- 如果我们想要灵活性(开启或关闭特征提取,或者对其进行变体修改),我们需要进一步调整源代码以支持这些需求。
- 这并不总是插入一行代码那么简单。想想如果按照我编写此模块的方式,你该如何获取中间某个块的特征图。
- 总的来说,当不需要改变模型工作方式时,我们更倾向于避免维护模型源代码的额外负担。
人们可以看出,当处理更大、更复杂的模型,并试图从嵌套子模块中获取特征时,这种缺点会变得更加棘手。
使用原始模块的参数编写一个新模块
沿用上面的例子,假设我们想从每个块中获取特征图。我们可以编写一个新模块,如下所示:
class CNNFeatures(nn.Module):
def __init__(self, backbone):
super().__init__()
self.blocks = backbone.blocks
def forward(self, x):
feature_maps = []
for block in self.blocks:
x = block(x)
feature_maps.append(x)
return feature_maps
backbone = CNN(3, 4, 10)
model = CNNFeatures(backbone)
out = model(torch.zeros(1, 3, 32, 32)) # This is now a list of Tensors, each representing a feature map
事实上,这与 TorchVision 内部构建许多检测模型所采用的方法非常相似。
虽然这种方法解决了一些直接修改源代码的问题,但仍存在一些主要缺点:
- 通常只能直接访问顶层子模块的输出。处理嵌套子模块会变得非常复杂。
- 我们必须小心,不能遗漏输入和输出之间的任何重要操作。在将原始模块的确切功能转录到新模块时,会引入潜在的错误。
总的来说,这种方法和前一种方法都存在将特征提取与模型源代码本身耦合的复杂性。事实上,如果我们检查 TorchVision 模型的源代码,我们可能会怀疑一些设计选择是为了方便在下游任务中以这种方式使用它们。
使用钩子 (Hooks)
Hooks 将我们从编写源代码的范式,转向了指定输出的范式。考虑我们上面的玩具 CNN 例子,以及获取每一层特征图的目标,我们可以像这样使用 Hooks:
model = CNN(3, 4, 10)
feature_maps = [] # This will be a list of Tensors, each representing a feature map
def hook_feat_map(mod, inp, out):
feature_maps.append(out)
for block in model.blocks:
block.register_forward_hook(hook_feat_map)
out = model(torch.zeros(1, 3, 32, 32)) # This will be the final logits over classes
现在我们在访问嵌套子模块方面拥有了完全的灵活性,并且免去了摆弄源代码的责任。但这种方法也有其自身的缺点:
- 我们只能将 Hooks 应用于模块。如果我们有需要获取输出的函数式操作(reshape、view、函数式非线性等),Hooks 将无法直接作用于它们。
- 我们没有修改源代码,因此无论 Hooks 如何设置,整个前向传播过程都会执行。如果我们只需要获取早期特征,而不需要最终输出,这可能会导致大量无用的计算。
- Hooks 不支持 TorchScript。
以下是不同方法的优缺点总结:
| 可以在不进行任何修改或重写的情况下使用源代码 | 获取特征具有完全的灵活性 | 丢弃不必要的计算步骤 | 支持 TorchScript | |
|---|---|---|---|---|
| 修改 forward 方法 | 否 | 技术上说是。取决于你愿意写多少代码。所以在实践中,否。 | 是 | 是 |
| 重用原始模块子模块/参数的新模块 | 否 | 技术上说是。取决于你愿意写多少代码。所以在实践中,否。 | 是 | 是 |
| Hooks | 是 | 基本是。仅限子模块输出 | 否 | 否 |
表 1:PyTorch 中几种现有特征提取方法的优缺点
在本文的下一节中,让我们看看如何实现全面的“是”。
FX 来拯救
对于 Python 和编程的初学者来说,此时最自然的问题可能是:“难道我们不能直接指向一行代码,告诉 Python 或 PyTorch 我们想要那行代码的结果吗?”对于那些花更多时间编码的人来说,原因很明确:一行代码中可能发生多个操作,无论是显式编写的,还是作为子操作隐含的。仅以这个简单的模块为例:
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.submodule = MySubModule()
def forward(self, x):
return self.submodule(x + self.param).clamp(min=0.0, max=1.0)
forward 方法有一行代码,我们可以将其拆解为:
- 将
self.param加到x上。 - 通过
self.submodule传递 x。这里我们需要考虑该子模块中发生的步骤。为了说明,我将使用虚拟操作名称:I. submodule.op_1 II. submodule.op_2 - 应用 clamp 操作。
因此,即使我们指向这一行,问题依然是:“我们要提取哪个步骤的输出?”
FX 是一个核心的 PyTorch 工具包,它(简单来说)完成了我刚才提到的拆解工作。它执行一种称为“符号追踪”的操作,这意味着 Python 代码会被解释,并使用某种真实输入的虚拟代理(proxy)逐个操作地执行。引入一些术语:上面描述的每个步骤都被视为一个“节点”(Node),连续的节点连接在一起形成一个“图”(Graph)。以下是上述“步骤”转化为这种图概念的形式:

图 3:对我们的简单 forward 方法示例进行符号追踪结果的图形表示。
注意,我们称之为“图”,而不仅仅是一系列步骤,因为图可能会分支和重组。想想残差块中的跳跃连接(skip connection),它看起来像这样:

图 4:残差跳跃连接的图形表示。中间的节点就像残差块的主分支,最终节点代表主分支输入和输出的总和。
现在,TorchVision 的 get_graph_node_names 函数应用了上述 FX,并在过程中为每个节点标记了一个人类可读的名称。让我们用上一节的玩具 CNN 模型尝试一下:
model = CNN(3, 4, 10)
from torchvision.models.feature_extraction import get_graph_node_names
nodes, _ = get_graph_node_names(model)
print(nodes)
结果如下:
['x', 'blocks.0.convs.0.0', 'blocks.0.convs.0.1', 'blocks.0.convs.1.0', 'blocks.0.convs.1.1', 'blocks.0.downsample', 'blocks.1.convs.0.0', 'blocks.1.convs.0.1', 'blocks.1.convs.1.0', 'blocks.1.convs.1.1', 'blocks.1.convs.2.0', 'blocks.1.convs.2.1', 'blocks.1.downsample', 'blocks.2.convs.0.0', 'blocks.2.convs.0.1', 'blocks.2.convs.1.0', 'blocks.2.convs.1.1', 'blocks.2.convs.2.0', 'blocks.2.convs.2.1', 'blocks.2.downsample', 'blocks.3.convs.0.0', 'blocks.3.convs.0.1', 'blocks.3.convs.1.0', 'blocks.3.convs.1.1', 'blocks.3.convs.2.0', 'blocks.3.convs.2.1', 'blocks.3.downsample', 'global_pool', 'flatten', 'cls']
我们可以将这些节点名称理解为感兴趣操作的层次化“地址”。例如,‘blocks.1.downsample’ 指的是第二个 ConvBlock 中的 MaxPool2d 层。
create_feature_extractor(魔法发生的地方)比 get_graph_node_names 更进一步。它将所需的节点名称作为输入参数之一,然后使用更多 FX 核心功能来:
- 将所需的节点分配为输出。
- 修剪掉不必要的下游节点及其相关参数。
- 将生成的图转换回 Python 代码。
- 返回另一个 PyTorch 模块给用户。该模块的 forward 方法就是第 3 步生成的 python 代码。
作为一个演示,这是我们如何应用 create_feature_extractor 从我们的玩具 CNN 模型中获取 4 个特征图:
from torchvision.models.feature_extraction import create_feature_extractor
# Confused about the node specification here?
# We are allowed to provide truncated node names, and `create_feature_extractor`
# will choose the last node with that prefix.
feature_extractor = create_feature_extractor(
model, return_nodes=['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3'])
# `out` will be a dict of Tensors, each representing a feature map
out = feature_extractor(torch.zeros(1, 3, 32, 32))
就是这么简单。归根结底,FX 特征提取只是让我们能够实现一些人在刚开始编程时天真地希望做到的事情:“直接把这段代码的输出给我(指着屏幕)”。
- ……不需要摆弄源代码。
- ……在访问输入的任何中间转换方面提供了完全的灵活性,无论是模块的输出还是函数操作。
- ……在提取特征后确实丢弃了不必要的计算步骤。
- ……而且我之前没提到过,它也支持 TorchScript!
这是加上 FX 特征提取后的表格:
| 可以在不进行任何修改或重写的情况下使用源代码 | 获取特征具有完全的灵活性 | 丢弃不必要的计算步骤 | 支持 TorchScript | |
|---|---|---|---|---|
| 修改 forward 方法 | 否 | 技术上说是。取决于你愿意写多少代码。所以在实践中,否。 | 是 | 是 |
| 重用原始模块子模块/参数的新模块 | 否 | 技术上说是。取决于你愿意写多少代码。所以在实践中,否。 | 是 | 是 |
| Hooks | 是 | 基本是。仅限子模块输出 | 否 | 否 |
| FX | 是 | 是 | 是 | 是 |
表 2:表 1 的副本,增加了 FX 特征提取行。FX 特征提取在所有方面都得到了“是”!
当前 FX 的局限性
虽然我很想在此时结束文章,但 FX 确实有一些局限性,归结为:
- 在解释和转换为图的步骤中,可能存在一些 FX 尚未处理的 Python 代码。
- 动态控制流无法表示为静态图。
当这些问题出现时,最简单的做法是将底层代码打包成一个“叶节点”(leaf node)。还记得图 3 的示例图吗?从概念上讲,我们可以同意将 submodule 本身视为一个节点,而不是一组代表底层操作的节点。如果我们这样做,我们可以将图重绘为:

图 5:如果将 submodule 视为“叶”节点,则其中的各个操作(左 – 红色框内)可以合并为一个节点(右 – 节点 #2)。
如果子模块中存在一些有问题的代码,但我们不需要从中提取任何中间转换,我们就希望这样做。在实践中,通过为 create_feature_extractor 或 get_graph_node_names 提供关键字参数即可轻松实现。
model = CNN(3, 4, 10)
nodes, _ = get_graph_node_names(model, tracer_kwargs={'leaf_modules': [ConvBlock]})
print(nodes)
输出如下:
['x', 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'global_pool', 'flatten', 'cls']
请注意,与之前相比,任何给定的 ConvBlock 的所有节点都被合并为一个单一节点。
我们也可以对函数做类似的事情。例如,Python 内置的 len 需要被封装,结果应被视为叶节点。以下是你可以使用核心 FX 功能实现的方法:
torch.fx.wrap('len')
class MyModule(nn.Module):
def forward(self, x):
x += 1
len(x)
model = MyModule()
feature_extractor = create_feature_extractor(model, return_nodes=['add'])
对于你定义的函数,你可以使用 create_feature_extractor 的另一个关键字参数(小细节:为什么你可能想以这种方式操作)。
def myfunc(x):
return len(x)
class MyModule(nn.Module):
def forward(self, x):
x += 1
myfunc(x)
model = MyModule()
feature_extractor = create_feature_extractor(
model, return_nodes=['add'], tracer_kwargs={'autowrap_functions': [myfunc]})
请注意,上述修复方法都没有涉及到修改源代码。
当然,有时你试图访问的中间转换恰好位于导致问题的同一个 forward 方法或函数中。此时,我们不能简单地将该模块或函数视为叶节点,因为那样我们就无法访问其中的中间转换。在这种情况下,需要对源代码进行一些重写。以下是一些例子(非详尽):
- 当尝试通过包含
assert语句的代码进行追踪时,FX 会引发错误。在这种情况下,你可能需要移除该断言,或将其替换为torch._assert(这不是一个公共函数,因此请将其视为权宜之计并谨慎使用)。 - 不支持对张量切片进行原地(in-place)修改的符号追踪。你需要为切片创建一个新变量,应用操作,然后使用串联(concatenation)或堆叠(stacking)重新构建原始张量。
- 在静态图中表示动态控制流在逻辑上是不可能的。看看是否可以将代码逻辑提炼为非动态的内容——有关技巧,请参阅 FX 文档。
通常,你可以查阅 FX 文档,以获取有关符号追踪局限性以及可能的变通方法的更多详细信息。
结论
我们快速回顾了特征提取以及为什么要进行特征提取。尽管现有的 PyTorch 特征提取方法存在显著缺点,但我们了解了 TorchVision 的 FX 特征提取工具的工作原理,以及它相比现有方法的通用性。虽然仍有一些小问题需要解决,但我们了解了其局限性,并可以根据用例在其他方法的局限性之间进行权衡。希望通过将这个新工具添加到你的 PyTorch 工具包中,你现在有能力处理可能遇到的绝大多数特征提取需求。
编码愉快!