博客

使用 Torch FX 在 TorchVision 中进行特征提取

作者 2021年10月29日2024年11月15日暂无评论

引言

基于 FX 的特征提取是 TorchVision 的一项新实用工具,它允许我们在 PyTorch 模块的前向传播过程中获取输入的中间变换结果。它的实现方式是对 forward 方法进行符号追踪(symbolically tracing),从而生成一个图,其中每个节点代表一个单一的操作。节点以人类可读的方式命名,以便用户可以轻松指定想要获取的节点。

听起来有点复杂?别担心,本文内容适合各类人群。无论你是初学者还是资深的深度视觉从业者,你都有可能需要了解 FX 特征提取。如果你想了解更多关于特征提取的背景知识,请继续阅读。如果你已经对此很熟悉,只想知道如何在 PyTorch 中实现,可以直接跳转到“PyTorch 中的现有方法:优缺点”。如果你已经了解在 PyTorch 中进行特征提取所面临的挑战,请随意跳到“FX 来救场”。

特征提取回顾

我们都习惯了深度神经网络 (DNN) 接收输入并产生输出的概念,而不一定会去思考中间发生了什么。让我们以 ResNet-50 分类模型为例:

CResNet-50 takes an image of a bird and transforms that into the abstract concept 'bird'
图 1:ResNet-50 接收一张鸟的图片,并将其转换为抽象概念“鸟”。来源:来自 ImageNet 的鸟类图像。

然而,我们知道在 ResNet-50 架构中有许多顺序排列的“层”,它们一步步地转换输入。在下方的图 2 中,我们深入底层展示了 ResNet-50 内部的层,并展示了输入在通过这些层时的中间变换过程。

ResNet-50 transforms the input image in multiple steps. Conceptually, we may access the intermediate transformation of the image after each one of these steps.
图 2:ResNet-50 分多个步骤转换输入图像。从概念上讲,我们可以在每个步骤之后获取图像的中间变换结果。来源:来自 ImageNet 的鸟类图像。

PyTorch 中的现有方法:优缺点

在引入基于 FX 的特征提取之前,PyTorch 中已经有几种进行特征提取的方法。

为了说明这些方法,我们考虑一个简单的卷积神经网络,它执行以下操作:

  • 应用几个“块”(block),每个块内部包含若干卷积层。
  • 在经过几个块之后,使用全局平均池化和展平操作。
  • 最后使用一个输出分类层。
import torch
from torch import nn


class ConvBlock(nn.Module):
   """
   Applies `num_layers` 3x3 convolutions each followed by ReLU then downsamples
   via 2x2 max pool.
   """

   def __init__(self, num_layers, in_channels, out_channels):
       super().__init__()
       self.convs = nn.ModuleList(
           [nn.Sequential(
               nn.Conv2d(in_channels if i==0 else out_channels, out_channels, 3, padding=1),
               nn.ReLU()
            )
            for i in range(num_layers)]
       )
       self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
      
   def forward(self, x):
       for conv in self.convs:
           x = conv(x)
       x = self.downsample(x)
       return x
      

class CNN(nn.Module):
   """
   Applies several ConvBlocks each doubling the number of channels, and
   halving the feature map size, before taking a global average and classifying.
   """

   def __init__(self, in_channels, num_blocks, num_classes):
       super().__init__()
       first_channels = 64
       self.blocks = nn.ModuleList(
           [ConvBlock(
               2 if i==0 else 3,
               in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
               out_channels=first_channels*(2**i))
            for i in range(num_blocks)]
       )
       self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
       self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)

   def forward(self, x):
       for block in self.blocks:
           x = block(x)
       x = self.global_pool(x)
       x = x.flatten(1)
       x = self.cls(x)
       return x


model = CNN(3, 4, 10)
out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

假设我们想要获取全局平均池化之前的最终特征图。我们可以采取以下做法:

修改 forward 方法

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   self.final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x

或者直接返回它

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x, final_feature_map

这看起来很容易。但这里存在一些缺点,它们都源于同一个根本问题:修改源代码并不理想。

  • 考虑到项目的实际情况,获取并修改源代码并不总是那么容易。
  • 如果我们想要灵活性(切换特征提取的开关,或对其进行变体修改),我们需要进一步调整源代码以提供支持。
  • 这不仅仅是插入一行代码那么简单。试想一下,按照我编写该模块的方式,你该如何获取中间某个块的特征图。
  • 总的来说,当不需要改变模型工作方式时,我们宁愿避免维护模型源代码所带来的额外负担。

可以看出,在处理更大、更复杂的模型,并试图获取嵌套子模块内部的特征时,这种缺点会变得更加棘手。

使用原模型的参数编写一个新模块

沿用上面的例子,假设我们想从每个块中获取特征图。我们可以这样编写一个新模块:

class CNNFeatures(nn.Module):
   def __init__(self, backbone):
       super().__init__()
       self.blocks = backbone.blocks

   def forward(self, x):
       feature_maps = []
       for block in self.blocks:
           x = block(x)
           feature_maps.append(x)
       return feature_maps


backbone = CNN(3, 4, 10)
model = CNNFeatures(backbone)
out = model(torch.zeros(1, 3, 32, 32))  # This is now a list of Tensors, each representing a feature map

事实上,这非常类似于 TorchVision 内部构建其许多检测模型所使用的方法。

尽管这种方法解决了一些直接修改源代码的问题,但仍存在一些主要缺点:

  • 它只能直观地获取顶层子模块的输出。处理嵌套子模块会迅速变得复杂。
  • 我们必须小心,不能遗漏输入和输出之间的任何重要操作。在将原始模块的功能准确转录到新模块时,我们引入了出错的可能性。

总的来说,这种方法和上一种方法都有一个共同的复杂点,即必须将特征提取与模型的源代码本身绑定在一起。确实,如果我们查看 TorchVision 模型的源代码,我们可能会怀疑某些设计选择是受限于为了在下游任务中以这种方式使用它们。

使用钩子 (Hooks)

Hooks 将我们从编写源代码的范式转向了指定输出的范式。考虑我们上面的 CNN 示例,目标是获取每一层的特征图,我们可以这样使用 Hooks:

model = CNN(3, 4, 10)
feature_maps = []  # This will be a list of Tensors, each representing a feature map

def hook_feat_map(mod, inp, out):
	feature_maps.append(out)

for block in model.blocks:
	block.register_forward_hook(hook_feat_map)

out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

现在,我们在访问嵌套子模块方面拥有了完全的灵活性,并且摆脱了修改源代码的负担。但这种方法也有其自身的缺点:

  • 我们只能将 Hooks 应用于模块(nn.Module)。如果我们有想要获取输出的函数操作(reshape, view, 函数式非线性变换等),Hooks 无法直接作用于它们。
  • 我们没有修改源代码的任何内容,因此整个前向传播过程都会执行,无论是否使用了 Hooks。如果我们只需要访问早期的特征而不需要最终输出,这可能会导致大量无用的计算。
  • Hooks 不支持 TorchScript。

以下是不同方法及其优缺点的总结:

可以在无需修改或重写的情况下直接使用源代码 访问特征的完全灵活性 丢弃不必要的计算步骤 支持 TorchScript
修改 forward 方法 技术上可以。取决于你愿意写多少代码。所以在实践中,否。
重用原模型子模块/参数的新模块 技术上可以。取决于你愿意写多少代码。所以在实践中,否。
Hooks 大部分是。仅限于子模块的输出

表 1:PyTorch 特征提取现有方法的部分优缺点

在本文的下一节中,让我们看看如何实现全方位的“是”。

FX 来救场

对于 Python 和编程的新手来说,此时一个自然的问题可能是:“难道我们不能直接指向一行代码,告诉 Python 或 PyTorch 我们想要那一行代码的结果吗?”对于那些有更多编程经验的人来说,原因显而易见:一行代码中可能包含多个操作,无论它们是被显式编写在那里的,还是作为子操作隐含存在的。仅以这个简单的模块为例:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.submodule = MySubModule()

    def forward(self, x):
        return self.submodule(x + self.param).clamp(min=0.0, max=1.0)

forward 方法有一行代码,我们可以将其拆解为:

  1. self.param 加到 x
  2. 将 x 传入 self.submodule。这里我们需要考虑该子模块中发生的步骤。为了说明,我将使用虚拟操作名称:I. submodule.op_1, II. submodule.op_2
  3. 应用 clamp 操作

所以,即使我们指向这一行代码,问题依然是:“我们想要提取哪一个步骤的输出?”

FX 是一个核心的 PyTorch 工具包,它(过度简化地说)完成了我刚才提到的拆解工作。它执行所谓的“符号追踪”,这意味着 Python 代码会被解释,并使用某个模拟真实输入的虚拟代理,一步步地执行操作。引入一些术语:如上所述,每个步骤被视为一个“节点”(node),连续的节点连接在一起形成一个“图”(graph)。以下是将上述“步骤”转换为这种图概念的结果:


图 3:对简单 forward 方法示例进行符号追踪结果的图形化表示。

注意我们将其称为图,而不仅仅是一组步骤,因为图可能会分叉和重组。想想残差块中的跳跃连接(skip connection),它看起来会像这样:


图 4:残差跳跃连接的图形化表示。中间节点类似于残差块的主分支,最终节点表示主分支输入和输出的和。

现在,TorchVision 的 get_graph_node_names 函数应用了上述的 FX,并在执行过程中为每个节点打上人类可读的名称。让我们尝试使用上一节中的 CNN 模型:

model = CNN(3, 4, 10)
from torchvision.models.feature_extraction import get_graph_node_names
nodes, _ = get_graph_node_names(model)
print(nodes)

这将产生以下结果:

['x', 'blocks.0.convs.0.0', 'blocks.0.convs.0.1', 'blocks.0.convs.1.0', 'blocks.0.convs.1.1', 'blocks.0.downsample', 'blocks.1.convs.0.0', 'blocks.1.convs.0.1', 'blocks.1.convs.1.0', 'blocks.1.convs.1.1', 'blocks.1.convs.2.0', 'blocks.1.convs.2.1', 'blocks.1.downsample', 'blocks.2.convs.0.0', 'blocks.2.convs.0.1', 'blocks.2.convs.1.0', 'blocks.2.convs.1.1', 'blocks.2.convs.2.0', 'blocks.2.convs.2.1', 'blocks.2.downsample', 'blocks.3.convs.0.0', 'blocks.3.convs.0.1', 'blocks.3.convs.1.0', 'blocks.3.convs.1.1', 'blocks.3.convs.2.0', 'blocks.3.convs.2.1', 'blocks.3.downsample', 'global_pool', 'flatten', 'cls']

我们可以将这些节点名称理解为感兴趣操作的层级化“地址”。例如,‘blocks.1.downsample’ 指的是第二个 ConvBlock 中的 MaxPool2d 层。

create_feature_extractor 是魔力发生的地方,它比 get_graph_node_names 更进一步。它将所需节点名称作为输入参数之一,然后利用更多的 FX 核心功能来:

  1. 将所需的节点指定为输出。
  2. 修剪不必要的下游节点及其关联参数。
  3. 将生成的图转换回 Python 代码。
  4. 返回另一个 PyTorch 模块给用户。该模块将步骤 3 中的 Python 代码作为 forward 方法。

作为一个演示,以下是我们如何应用 create_feature_extractor 来从我们的 CNN 模型中获取 4 个特征图:

from torchvision.models.feature_extraction import create_feature_extractor
# Confused about the node specification here?
# We are allowed to provide truncated node names, and `create_feature_extractor`
# will choose the last node with that prefix.
feature_extractor = create_feature_extractor(
	model, return_nodes=['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3'])
# `out` will be a dict of Tensors, each representing a feature map
out = feature_extractor(torch.zeros(1, 3, 32, 32))

就是这么简单。归根结底,FX 特征提取只是让我们能够实现我们在刚开始编程时天真地希望做到的事情:“直接给我这段代码的输出(手指指向屏幕)”。

  • …不需要我们摆弄源代码。
  • …在访问输入的任何中间变换方面提供了完全的灵活性,无论它们是模块的结果还是函数操作的结果。
  • …在提取特征后确实丢弃了不必要的计算步骤。
  • …我之前没提到,但它也对 TorchScript 友好!

这是之前的表格,添加了关于 FX 特征提取的一行:

可以在无需修改或重写的情况下直接使用源代码 访问特征的完全灵活性 丢弃不必要的计算步骤 支持 TorchScript
修改 forward 方法 技术上可以。取决于你愿意写多少代码。所以在实践中,否。
重用原模型子模块/参数的新模块 技术上可以。取决于你愿意写多少代码。所以在实践中,否。
Hooks 大部分是。仅限于子模块的输出
FX

表 2:表 1 的副本,增加了 FX 特征提取行。FX 特征提取在所有方面都是“是”!

当前 FX 的局限性

虽然我本来想在那结束这篇文章,但 FX 确实有一些局限性,归结为:

  1. 在解释和翻译成图的步骤中,可能仍有一些 Python 代码尚未被 FX 处理。
  2. 动态控制流无法表示为静态图。

当这些问题出现时,最简单的做法是将基础代码捆绑成一个“叶节点”(leaf node)。还记得图 3 的示例图吗?从概念上讲,我们可能认为 submodule 应该被视为一个单独的节点,而不是一组代表底层操作的节点。如果我们这样做,我们可以将图重绘为:


图 5:如果我们将 `submodule` 视为“叶”节点,则 `submodule` 内部的单个操作(左侧 – 红色框内)可以合并为一个节点(右侧 – 节点 #2)。

如果子模块中存在一些有问题的代码,但我们不需要从内部提取任何中间变换,我们就会希望这样做。在实践中,通过向 create_feature_extractor 或 get_graph_node_names 提供关键字参数即可轻松实现。

model = CNN(3, 4, 10)
nodes, _ = get_graph_node_names(model, tracer_kwargs={'leaf_modules': [ConvBlock]})
print(nodes)

其输出为:

['x', 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'global_pool', 'flatten', 'cls']

注意,与之前相比,任何给定 ConvBlock 的所有节点都被合并成了一个单独的节点。

我们也可以对函数做类似的事情。例如,Python 内置的 len 需要被封装,其结果应被视为叶节点。以下是使用 FX 核心功能实现的方法:

torch.fx.wrap('len')

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       len(x)

model = MyModule()
feature_extractor = create_feature_extractor(model, return_nodes=['add'])

对于你定义的函数,你可以改用 create_feature_extractor 的另一个关键字参数(小细节:这是你可能想用这种方式的原因)。

def myfunc(x):
   return len(x)

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       myfunc(x)

model = MyModule()
feature_extractor = create_feature_extractor(
   model, return_nodes=['add'], tracer_kwargs={'autowrap_functions': [myfunc]})

请注意,上述修复方法均不涉及修改源代码。

当然,有时你试图访问的中间变换恰好就在导致问题的同一个 forward 方法或函数内部。在这种情况下,我们不能简单地将该模块或函数视为叶节点,因为这样我们就无法访问其中的中间变换。在这些情况下,需要重写一些源代码。以下是一些示例(并非详尽无遗):

  • 当尝试追踪带有 assert 语句的代码时,FX 会引发错误。在这种情况下,你可能需要删除该断言或将其替换为 torch._assert(这不是公共函数,所以将其视为临时补丁并谨慎使用)。
  • 不支持对张量切片的原地(in-place)修改的符号追踪。你需要为切片创建一个新变量,应用操作,然后使用连接(concatenation)或堆叠(stacking)重建原始张量。
  • 在静态图中表示动态控制流在逻辑上是不可能的。看看是否可以将编码逻辑提炼为非动态的内容 – 请参阅 FX 文档以获取提示。

通常,你可以查阅 FX 文档以获取关于 符号追踪局限性 及其可能的变通方法的更多详细信息。

结论

我们快速回顾了特征提取以及为什么要进行特征提取。虽然 PyTorch 中已经存在用于特征提取的方法,但它们都有相当明显的缺点。我们了解了 TorchVision 的 FX 特征提取工具是如何工作的,以及它与现有方法相比为何如此通用。虽然该方法仍有一些小问题需要解决,但我们了解其局限性,并可以根据我们的用例在其他方法的局限性之间进行权衡。希望通过将此新工具添加到你的 PyTorch 工具包中,你现在有能力处理你可能遇到的绝大多数特征提取需求。

编码愉快!