跳转到主要内容
博客

使用 Torch FX 在 TorchVision 中进行特征提取

作者: 2021 年 10 月 29 日2024 年 11 月 15 日暂无评论

引言

FX 基于的特征提取是 TorchVision 的一项新实用工具,它允许我们在 PyTorch 模块的前向传播过程中访问输入的中间变换。它通过符号化地跟踪前向方法来生成一个图,其中每个节点代表一个操作。节点的命名方式便于人工阅读,因此可以轻松指定要访问的节点。

这一切听起来有点复杂吗?不用担心,本文中总有一些内容适合所有人。无论您是初学者还是高级深度视觉实践者,您都可能想了解 FX 特征提取。如果您仍然需要更多关于特征提取的背景知识,请继续阅读。如果您已经熟悉了这一点,并想知道如何在 PyTorch 中进行操作,请跳到“PyTorch 中的现有方法:优缺点”。如果您已经了解在 PyTorch 中进行特征提取的挑战,请随意跳到“FX 来救援”。

特征提取回顾

我们都习惯了深度神经网络(DNN)接收输入并产生输出的概念,我们不一定考虑中间发生了什么。我们以 ResNet-50 分类模型为例

CResNet-50 takes an image of a bird and transforms that into the abstract concept 'bird'
图 1:ResNet-50 接收一张鸟的图像,并将其转换为抽象概念“鸟”。来源:ImageNet 中的鸟图像。

然而,我们知道 ResNet-50 架构中有许多顺序的“层”,它们逐步转换输入。在下面的图 2 中,我们深入探讨了 ResNet-50 中的层,并展示了输入通过这些层时的中间变换。

ResNet-50 transforms the input image in multiple steps. Conceptually, we may access the intermediate transformation of the image after each one of these steps.
图 2:ResNet-50 分多步转换输入图像。从概念上讲,我们可以在每一步之后访问图像的中间变换。来源:ImageNet 中的鸟图像。

PyTorch 中的现有方法:优缺点

在基于 FX 的特征提取引入之前,PyTorch 中已经存在几种进行特征提取的方法。

为了说明这些,我们考虑一个执行以下操作的简单卷积神经网络

  • 应用多个“块”,每个块内包含多个卷积层。
  • 在几个块之后,它使用全局平均池化和展平操作。
  • 最后,它使用一个单一的输出分类层。
import torch
from torch import nn


class ConvBlock(nn.Module):
   """
   Applies `num_layers` 3x3 convolutions each followed by ReLU then downsamples
   via 2x2 max pool.
   """

   def __init__(self, num_layers, in_channels, out_channels):
       super().__init__()
       self.convs = nn.ModuleList(
           [nn.Sequential(
               nn.Conv2d(in_channels if i==0 else out_channels, out_channels, 3, padding=1),
               nn.ReLU()
            )
            for i in range(num_layers)]
       )
       self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
      
   def forward(self, x):
       for conv in self.convs:
           x = conv(x)
       x = self.downsample(x)
       return x
      

class CNN(nn.Module):
   """
   Applies several ConvBlocks each doubling the number of channels, and
   halving the feature map size, before taking a global average and classifying.
   """

   def __init__(self, in_channels, num_blocks, num_classes):
       super().__init__()
       first_channels = 64
       self.blocks = nn.ModuleList(
           [ConvBlock(
               2 if i==0 else 3,
               in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
               out_channels=first_channels*(2**i))
            for i in range(num_blocks)]
       )
       self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
       self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)

   def forward(self, x):
       for block in self.blocks:
           x = block(x)
       x = self.global_pool(x)
       x = x.flatten(1)
       x = self.cls(x)
       return x


model = CNN(3, 4, 10)
out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

假设我们想在全局平均池化之前获取最终的特征图。我们可以这样做

修改 forward 方法

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   self.final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x

或直接返回它

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x, final_feature_map

这看起来很简单。但是这里有一些缺点,它们都源于同一个根本问题:修改源代码并不理想

  • 考虑到项目的实际情况,访问和更改并不总是那么容易。
  • 如果我们想要灵活性(打开或关闭特征提取,或对其进行变体),我们需要进一步修改源代码来支持这一点。
  • 这并不总是只插入一行代码的问题。想想看,如果按照我编写这个模块的方式,你如何从其中一个中间块获取特征图。
  • 总的来说,我们宁愿避免维护模型源代码的开销,因为我们实际上不需要改变它的工作方式。

人们可以看到,当处理更大、更复杂的模型,并试图从嵌套的子模块中获取特征时,这种缺点会变得更加棘手。

使用原始模块的参数编写一个新模块

延续上面的例子,假设我们想从每个块中获取特征图。我们可以这样编写一个新模块

class CNNFeatures(nn.Module):
   def __init__(self, backbone):
       super().__init__()
       self.blocks = backbone.blocks

   def forward(self, x):
       feature_maps = []
       for block in self.blocks:
           x = block(x)
           feature_maps.append(x)
       return feature_maps


backbone = CNN(3, 4, 10)
model = CNNFeatures(backbone)
out = model(torch.zeros(1, 3, 32, 32))  # This is now a list of Tensors, each representing a feature map

事实上,这很像 TorchVision 内部用于制作其许多检测模型的方法。

尽管这种方法解决了一些直接修改源代码的问题,但仍有一些主要缺点

  • 它只能非常直接地访问顶级子模块的输出。处理嵌套子模块会迅速变得复杂。
  • 我们必须小心,不要遗漏输入和输出之间的任何重要操作。我们将引入在将原始模块的确切功能转录到新模块时出现错误的潜在可能性。

总的来说,这种方法和上一种方法都有一个共同的复杂性,即将特征提取与模型的源代码本身联系在一起。事实上,如果我们检查 TorchVision 模型的源代码,我们可能会怀疑某些设计选择受到了以这种方式将其用于下游任务的愿望的影响。

使用钩子(hooks)

钩子将我们从编写源代码的范式转向指定输出的范式。考虑到我们上面的玩具 CNN 示例,以及获取每一层特征图的目标,我们可以像这样使用钩子

model = CNN(3, 4, 10)
feature_maps = []  # This will be a list of Tensors, each representing a feature map

def hook_feat_map(mod, inp, out):
	feature_maps.append(out)

for block in model.blocks:
	block.register_forward_hook(hook_feat_map)

out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

现在我们在访问嵌套子模块方面拥有完全的灵活性,并且我们摆脱了修改源代码的责任。但这种方法也有其自身的缺点

  • 我们只能将钩子应用于模块。如果我们有需要其输出的功能操作(reshape、view、功能性非线性等),钩子将无法直接在其上工作。
  • 我们没有修改源代码的任何内容,因此无论钩子如何,整个前向传播都会执行。如果只需要访问早期特征而不需要最终输出,这可能会导致大量无用的计算。
  • 钩子与 TorchScript 不兼容。

以下是不同方法的优缺点总结

无需任何修改或重写即可使用源代码 访问特征的完全灵活性 舍弃不必要的计算步骤 兼容 TorchScript
修改前向方法 从技术上讲,是的。取决于你愿意编写多少代码。所以实际上,否。
重用原始模块的子模块/参数的新模块 从技术上讲,是的。取决于你愿意编写多少代码。所以实际上,否。
钩子 大部分是。仅限于子模块的输出

表 1:使用 PyTorch 进行特征提取的现有方法的一些优缺点

在本文的下一节中,让我们看看如何全面实现“是”。

FX 来救援

此时,对于一些 Python 和编码的新手来说,自然会问:“我们不能直接指向一行代码,然后告诉 Python 或 PyTorch 我们想要那行代码的结果吗?”对于那些花更多时间编码的人来说,无法做到这一点的原因很清楚:一行代码中可能发生多个操作,无论它们是显式写在那里,还是作为子操作隐式存在。以这个简单的模块为例

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.submodule = MySubModule()

    def forward(self, x):
        return self.submodule(x + self.param).clamp(min=0.0, max=1.0)

前向方法只有一行代码,我们可以将其展开为

  1. self.param 添加到 x
  2. 将 x 传递给 self.submodule。这里我们需要考虑该子模块中发生的步骤。我将只使用虚拟操作名称进行说明:I. submodule.op_1 II. submodule.op_2
  3. 应用 clamp 操作

所以即使我们指向这一行,问题就是:“我们想提取哪个步骤的输出?”。

FX 是一个 PyTorch 核心工具包,它(简化地说)执行了我刚才提到的展开。它执行一种称为“符号跟踪”的操作,这意味着 Python 代码被解释并逐步执行,逐个操作,使用一些虚拟代理作为真实输入。引入一些术语,上面描述的每个步骤都被视为一个**“节点”**,连续的节点相互连接形成一个**“图”**(与常见的图的数学概念类似)。以下是将上述“步骤”翻译成图的概念。


图 3:符号跟踪简单前向方法的示例结果的图形表示。

请注意,我们称之为图,而不仅仅是一组步骤,因为图可以分支并重新组合。想想残差块中的跳跃连接。这看起来像这样


图 4:残差跳跃连接的图形表示。中间节点类似于残差块的主分支,最终节点表示主分支的输入和输出之和。

现在,TorchVision 的 **get_graph_node_names** 函数如上所述应用 FX,并在此过程中为每个节点标记一个人类可读的名称。让我们用上一节中的玩具 CNN 模型来尝试一下

model = CNN(3, 4, 10)
from torchvision.models.feature_extraction import get_graph_node_names
nodes, _ = get_graph_node_names(model)
print(nodes)

这将导致

['x', 'blocks.0.convs.0.0', 'blocks.0.convs.0.1', 'blocks.0.convs.1.0', 'blocks.0.convs.1.1', 'blocks.0.downsample', 'blocks.1.convs.0.0', 'blocks.1.convs.0.1', 'blocks.1.convs.1.0', 'blocks.1.convs.1.1', 'blocks.1.convs.2.0', 'blocks.1.convs.2.1', 'blocks.1.downsample', 'blocks.2.convs.0.0', 'blocks.2.convs.0.1', 'blocks.2.convs.1.0', 'blocks.2.convs.1.1', 'blocks.2.convs.2.0', 'blocks.2.convs.2.1', 'blocks.2.downsample', 'blocks.3.convs.0.0', 'blocks.3.convs.0.1', 'blocks.3.convs.1.0', 'blocks.3.convs.1.1', 'blocks.3.convs.2.0', 'blocks.3.convs.2.1', 'blocks.3.downsample', 'global_pool', 'flatten', 'cls']

我们可以将这些节点名称理解为感兴趣操作的分层组织“地址”。例如,’blocks.1.downsample’ 指的是第二个 ConvBlock 中的 MaxPool2d 层。

create_feature_extractor 是所有魔法发生的地方,它比 **get_graph_node_names** 更进一步。它将所需的节点名称作为输入参数之一,然后使用更多的 FX 核心功能来

  1. 将所需的节点指定为输出。
  2. 修剪不必要的下游节点及其相关参数。
  3. 将生成的图转换回 Python 代码。
  4. 向用户返回另一个 PyTorch 模块。该模块的 forward 方法是第 3 步生成的 Python 代码。

作为演示,以下是如何应用 create_feature_extractor 来从我们的玩具 CNN 模型中获取 4 个特征图

from torchvision.models.feature_extraction import create_feature_extractor
# Confused about the node specification here?
# We are allowed to provide truncated node names, and `create_feature_extractor`
# will choose the last node with that prefix.
feature_extractor = create_feature_extractor(
	model, return_nodes=['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3'])
# `out` will be a dict of Tensors, each representing a feature map
out = feature_extractor(torch.zeros(1, 3, 32, 32))

就这么简单。归根结底,FX 特征提取只是一种方法,它使得我们中的一些人最初编程时天真地希望做到的事情成为可能:“给我这段代码的输出(指着屏幕)”*。

  • ... 不需要我们修改源代码。
  • ... 在访问输入的所有中间变换方面提供了完全的灵活性,无论它们是模块还是功能操作的结果
  • ... 一旦提取了特征,确实会舍弃不必要的计算步骤
  • ... 我之前没有提到这一点,但它也与 TorchScript 兼容!

这是再次添加了 FX 特征提取行后的表格

无需任何修改或重写即可使用源代码 访问特征的完全灵活性 舍弃不必要的计算步骤 兼容 TorchScript
修改前向方法 从技术上讲,是的。取决于你愿意编写多少代码。所以实际上,否。
重用原始模块的子模块/参数的新模块 从技术上讲,是的。取决于你愿意编写多少代码。所以实际上,否。
钩子 大部分是。仅限于子模块的输出
FX

表 2:表 1 的副本,增加了 FX 特征提取行。FX 特征提取在所有方面都获得了“是”!

当前 FX 限制

尽管我很想就此结束这篇文章,但 FX 确实有一些自身的局限性,可以归结为

  1. 在解释和转换为图的步骤中,可能有一些 Python 代码尚未被 FX 处理。
  2. 动态控制流无法用静态图表示。

当出现这些问题时,最简单的方法是将底层代码打包成一个“叶节点”。回想一下图 3 中的示例图?从概念上讲,我们可能同意 submodule 应该被视为一个节点本身,而不是一组表示底层操作的节点。如果我们这样做,我们可以将图重新绘制为


图 5:`submodule` 内的单个操作(左图红色方框内)可以合并为一个节点(右图节点 #2),如果我们将其视为一个“叶”节点。

如果子模块中存在一些有问题的代码,但我们不需要从中提取任何中间变换,我们就会这样做。实际上,通过向 create_feature_extractor 或 get_graph_node_names 提供关键字参数,可以轻松实现这一点。

model = CNN(3, 4, 10)
nodes, _ = get_graph_node_names(model, tracer_kwargs={'leaf_modules': [ConvBlock]})
print(nodes)

其输出将是

['x', 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'global_pool', 'flatten', 'cls']

请注意,与之前相比,任何给定 ConvBlock 的所有节点都被合并为一个节点。

我们可以对函数做类似的事情。例如,Python 内置的 len 需要被包装,结果应该被视为叶节点。以下是如何使用核心 FX 功能实现这一点

torch.fx.wrap('len')

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       len(x)

model = MyModule()
feature_extractor = create_feature_extractor(model, return_nodes=['add'])

对于您定义的函数,您可以改用 `create_feature_extractor` 的另一个关键字参数(小细节:为什么您可能更愿意这样做

def myfunc(x):
   return len(x)

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       myfunc(x)

model = MyModule()
feature_extractor = create_feature_extractor(
   model, return_nodes=['add'], tracer_kwargs={'autowrap_functions': [myfunc]})

请注意,上述所有修复都没有涉及修改源代码。

当然,有时人们试图访问的中间变换恰好位于导致问题的相同前向方法或函数中。在这种情况下,我们不能仅仅将该模块或函数视为叶节点,因为那样我们就无法访问其中的中间变换。在这些情况下,需要对源代码进行一些重写。以下是一些示例(不详尽)

  • 当尝试跟踪包含 assert 语句的代码时,FX 会引发错误。在这种情况下,您可能需要删除该断言或将其替换为 torch._assert(这不是一个公共函数 – 因此将其视为权宜之计,请谨慎使用)。
  • 不支持对张量切片进行符号跟踪就地更改。您需要为切片创建一个新变量,应用操作,然后使用连接或堆叠重建原始张量。
  • 在静态图中表示动态控制流在逻辑上是不可能的。看看你是否可以将编码逻辑精简为非动态的内容——请参阅 FX 文档获取提示。

通常,您可以查阅 FX 文档,了解有关符号跟踪限制和可能的解决方法更详细的信息。

结论

我们快速回顾了特征提取以及为什么有人可能想要这样做。尽管 PyTorch 中存在进行特征提取的现有方法,但它们都有相当显著的缺点。我们了解了 TorchVision 的 FX 特征提取实用工具如何工作以及是什么让它比现有方法如此多功能。虽然后者仍然有一些小问题需要解决,但我们了解了这些限制,并可以根据我们的用例权衡它们与其它方法的限制。希望通过将这个新实用工具添加到您的 PyTorch 工具包中,您现在能够处理绝大多数您可能遇到的特征提取需求。

编程愉快!