作者:Alexander Soare 和 Francisco Massa

引言

FX 基于 FX 的特征提取是一种新的 TorchVision 实用工具,它允许我们在 PyTorch 模块的前向传播过程中访问输入的中间转换结果。它通过符号追踪 (symbolically tracing) forward 方法来实现这一点,从而生成一个图,图中的每个节点代表一个单一的操作。节点以人类可读的方式命名,以便可以轻松指定要访问哪些节点。

这一切听起来有点复杂吗?不用担心,本文的内容适合所有人。无论您是初学者还是高级深度视觉实践者,很可能都会想了解 FX 特征提取。如果您想了解更多关于通用特征提取的背景知识,请继续阅读。如果您对此已很熟悉并想了解如何在 PyTorch 中实现,请跳到“PyTorch 中的现有方法:优缺点”。如果您已经了解在 PyTorch 中进行特征提取的挑战,请随意跳到“FX 救援行动”。

特征提取回顾

我们都习惯了深度神经网络 (DNN) 接收输入并产生输出的想法,而我们不一定会去思考中间发生了什么。我们以 ResNet-50 分类模型为例

CResNet-50 takes an image of a bird and transforms that into the abstract concept 'bird'
图 1:ResNet-50 输入一张鸟的图片,并将其转换为抽象概念“鸟”。来源:ImageNet 中的鸟类图片。

但我们知道,ResNet-50 架构中有许多连续的“层”,它们逐步转换输入。在下面的图 2 中,我们查看了 ResNet-50 内部的层,并展示了输入通过这些层时的中间转换结果。

ResNet-50 transforms the input image in multiple steps. Conceptually, we may access the intermediate transformation of the image after each one of these steps.
图 2:ResNet-50 分多步转换输入图像。从概念上讲,我们可以在每个步骤之后访问图像的中间转换结果。来源:ImageNet 中的鸟类图片。

PyTorch 中的现有方法:优缺点

在引入基于 FX 的特征提取之前,PyTorch 中已经存在一些进行特征提取的方法。

为了说明这些方法,我们来看一个简单的卷积神经网络,它执行以下操作:

  • 应用多个“块”,每个块内包含多个卷积层。
  • 经过几个块后,它使用全局平均池化和展平操作。
  • 最后使用一个输出分类层。
import torch
from torch import nn


class ConvBlock(nn.Module):
   """
   Applies `num_layers` 3x3 convolutions each followed by ReLU then downsamples
   via 2x2 max pool.
   """

   def __init__(self, num_layers, in_channels, out_channels):
       super().__init__()
       self.convs = nn.ModuleList(
           [nn.Sequential(
               nn.Conv2d(in_channels if i==0 else out_channels, out_channels, 3, padding=1),
               nn.ReLU()
            )
            for i in range(num_layers)]
       )
       self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
      
   def forward(self, x):
       for conv in self.convs:
           x = conv(x)
       x = self.downsample(x)
       return x
      

class CNN(nn.Module):
   """
   Applies several ConvBlocks each doubling the number of channels, and
   halving the feature map size, before taking a global average and classifying.
   """

   def __init__(self, in_channels, num_blocks, num_classes):
       super().__init__()
       first_channels = 64
       self.blocks = nn.ModuleList(
           [ConvBlock(
               2 if i==0 else 3,
               in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
               out_channels=first_channels*(2**i))
            for i in range(num_blocks)]
       )
       self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
       self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)

   def forward(self, x):
       for block in self.blocks:
           x = block(x)
       x = self.global_pool(x)
       x = x.flatten(1)
       x = self.cls(x)
       return x


model = CNN(3, 4, 10)
out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

假设我们想在全局平均池化之前获取最终的特征图。我们可以这样做:

修改 forward 方法

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   self.final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x

或者直接返回它

def forward(self, x):
   for block in self.blocks:
       x = block(x)
   final_feature_map = x
   x = self.global_pool(x)
   x = x.flatten(1)
   x = self.cls(x)
   return x, final_feature_map

这看起来相当容易。但这里有一些缺点,它们都源于同一个根本问题:修改源代码并不理想

  • 考虑到项目的实际情况,访问和更改源代码并不总是容易的。
  • 如果我们需要灵活性(开启或关闭特征提取,或者有不同的变体),我们需要进一步修改源代码来支持这一点。
  • 这不总是插入一行代码那么简单。想想看,按照我编写这个模块的方式,如何从其中一个中间块获取特征图。
  • 总的来说,当模型的实际工作方式不需要任何改变时,我们宁愿避免维护模型源代码的额外开销。

可以想象,当处理更大、更复杂的模型,并试图从嵌套子模块中获取特征时,这个缺点会变得更加棘手。

使用原始模块的参数编写新模块

沿用上面的例子,假设我们想从每个块获取特征图。我们可以这样编写一个新模块:

class CNNFeatures(nn.Module):
   def __init__(self, backbone):
       super().__init__()
       self.blocks = backbone.blocks

   def forward(self, x):
       feature_maps = []
       for block in self.blocks:
           x = block(x)
           feature_maps.append(x)
       return feature_maps


backbone = CNN(3, 4, 10)
model = CNNFeatures(backbone)
out = model(torch.zeros(1, 3, 32, 32))  # This is now a list of Tensors, each representing a feature map

事实上,这很像 TorchVision 内部用于构建许多检测模型的方法。

尽管这种方法解决了一些直接修改源代码带来的问题,但仍然存在一些主要缺点

  • 它只方便访问顶层子模块的输出。处理嵌套子模块会迅速变得复杂。
  • 我们必须小心,不要遗漏输入和输出之间的任何重要操作。在将原始模块的精确功能转录到新模块时,我们引入了潜在的错误风险。

总的来说,这种方法和前一种方法都有将特征提取与模型源代码本身绑定在一起的复杂性。事实上,如果我们检查 TorchVision 模型的源代码,我们可能会怀疑一些设计选择受到了以这种方式将它们用于下游任务的需求的影响。

使用 hooks

Hooks 使我们摆脱了编写源代码的范例,转向指定输出的范例。考虑我们上面的玩具 CNN 示例,以及获取每一层特征图的目标,我们可以这样使用 hooks:

model = CNN(3, 4, 10)
feature_maps = []  # This will be a list of Tensors, each representing a feature map

def hook_feat_map(mod, inp, out):
	feature_maps.append(out)

for block in model.blocks:
	block.register_forward_hook(hook_feat_map)

out = model(torch.zeros(1, 3, 32, 32))  # This will be the final logits over classes

现在我们在访问嵌套子模块方面拥有了完全的灵活性,并且摆脱了修改源代码的责任。但这种方法也有其自身的缺点

  • 我们只能将 hooks 应用于模块。如果我们想要获取函数式操作(如 reshape、view、函数式非线性激活等)的输出,hooks 无法直接作用于它们。
  • 我们没有修改源代码的任何内容,因此整个前向传播都会执行,无论是否使用了 hooks。如果我们只需要访问早期特征而不需要最终输出,这可能会导致大量无用的计算。
  • Hooks 不支持 TorchScript。

以下是不同方法及其优缺点的总结:

  可以使用源代码,无需任何修改或重写 访问特征的完全灵活性 丢弃不必要的计算步骤 TorchScript 兼容
修改 forward 方法 技术上可以。取决于你愿意写多少代码。所以实际上是“否”。
重用原始模块的子模块/参数的新模块 技术上可以。取决于你愿意写多少代码。所以实际上是“否”。
Hooks 大部分是。仅限于子模块的输出

表 1:PyTorch 中现有特征提取方法的优点(或缺点)

在本文的下一节中,我们将看看如何实现全面“是”。

FX 救援行动

对于一些 Python 和编程新手来说,此时可能会自然地问:“我们不能直接指向一行代码,然后告诉 Python 或 PyTorch 我们想要那行代码的结果吗?”对于那些有更多编程经验的人来说,原因很明显:一行代码中可以发生多个操作,无论它们是显式写在那里,还是作为子操作隐式存在。就拿这个简单的模块为例

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.submodule = MySubModule()

    def forward(self, x):
        return self.submodule(x + self.param).clamp(min=0.0, max=1.0)

forward 方法只有一行代码,我们可以将其展开为:

  1. self.param 添加到 x
  2. 将 x 通过 self.submodule。这里我们需要考虑该子模块中发生的步骤。我将使用一些虚拟操作名称进行说明:I. submodule.op_1 II. submodule.op_2
  3. 应用 clamp 操作

所以即使我们指向这一行,问题是:“我们想提取哪个步骤的输出?”

FX 是 PyTorch 的一个核心工具包,它(简单来说)执行了我刚才提到的展开操作。它执行一种称为“符号追踪 (symbolic tracing)”的操作,这意味着 Python 代码会被解释并逐个操作地执行,使用一个虚拟代理作为真实输入。引入一些术语,上述的每个步骤被视为一个 “节点”,连续的节点相互连接形成一个 “图” (这与数学中图的概念并无太大区别)。以下是将上述“步骤”转换为图的概念的表示。

Graphical representation of the result of symbolically tracing our example of a simple forward method.
图 3:符号追踪我们的简单 forward 方法示例所得结果的图形表示。

请注意,我们将其称为图,而不仅仅是一组步骤,因为图可以分支和重新组合。想想残差块中的跳跃连接 (skip connection)。它看起来像这样:

Graphical representation of a residual skip connection. The middle node is like the main branch of a residual block, and the final node represents the sum of the input and output of the main branch.
图 4:残差跳跃连接的图形表示。中间节点类似于残差块的主分支,最终节点表示主分支的输入和输出之和。

现在,TorchVision 的 get_graph_node_names 函数应用了上述 FX 方法,并在过程中为每个节点标记了一个人类可读的名称。让我们用上一节中的玩具 CNN 模型来尝试一下

model = CNN(3, 4, 10)
from torchvision.models.feature_extraction import get_graph_node_names
nodes, _ = get_graph_node_names(model)
print(nodes)

结果如下:

['x', 'blocks.0.convs.0.0', 'blocks.0.convs.0.1', 'blocks.0.convs.1.0', 'blocks.0.convs.1.1', 'blocks.0.downsample', 'blocks.1.convs.0.0', 'blocks.1.convs.0.1', 'blocks.1.convs.1.0', 'blocks.1.convs.1.1', 'blocks.1.convs.2.0', 'blocks.1.convs.2.1', 'blocks.1.downsample', 'blocks.2.convs.0.0', 'blocks.2.convs.0.1', 'blocks.2.convs.1.0', 'blocks.2.convs.1.1', 'blocks.2.convs.2.0', 'blocks.2.convs.2.1', 'blocks.2.downsample', 'blocks.3.convs.0.0', 'blocks.3.convs.0.1', 'blocks.3.convs.1.0', 'blocks.3.convs.1.1', 'blocks.3.convs.2.0', 'blocks.3.convs.2.1', 'blocks.3.downsample', 'global_pool', 'flatten', 'cls']

我们可以将这些节点名称理解为分层组织的、感兴趣的操作的“地址”。例如,‘blocks.1.downsample’ 指的是第二个 ConvBlock 中的 MaxPool2d 层。

create_feature_extractor,这是所有神奇之处发生的地方,比 get_graph_node_names 更进一步。它将期望的节点名称作为输入参数之一,然后使用更多的 FX 核心功能来

  1. 将期望的节点指定为输出。
  2. 裁剪不必要的下游节点及其相关参数。
  3. 将生成的图转换回 Python 代码。
  4. 向用户返回另一个 PyTorch 模块。这个模块的 forward 方法就是步骤 3 中的 Python 代码。

作为一个演示,以下是如何应用 create_feature_extractor 从我们的玩具 CNN 模型中获取 4 个特征图:

from torchvision.models.feature_extraction import create_feature_extractor
# Confused about the node specification here?
# We are allowed to provide truncated node names, and `create_feature_extractor`
# will choose the last node with that prefix.
feature_extractor = create_feature_extractor(
	model, return_nodes=['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3'])
# `out` will be a dict of Tensors, each representing a feature map
out = feature_extractor(torch.zeros(1, 3, 32, 32))

就这么简单。归根结底,FX 特征提取只是一种方法,它使我们初学编程时可能天真地希望实现的事情成为可能:“把这段代码的输出给我(手指屏幕)”*。

  • … 不需要我们修改源代码。
  • … 在访问输入数据的任何中间转换结果方面提供了完全的灵活性,无论它们是模块还是函数式操作的结果
  • … 在提取特征后会丢弃不必要的计算步骤
  • … 我之前没有提到这一点,但它也支持 TorchScript!

这是那张表格,增加了 FX 特征提取一行

  可以使用源代码,无需任何修改或重写 访问特征的完全灵活性 丢弃不必要的计算步骤 TorchScript 兼容
修改 forward 方法 技术上可以。取决于你愿意写多少代码。所以实际上是“否”。
重用原始模块的子模块/参数的新模块 技术上可以。取决于你愿意写多少代码。所以实际上是“否”。
Hooks 大部分是。仅限于子模块的输出
FX

表 2:表 1 的副本,增加了 FX 特征提取一行。FX 特征提取全面实现“是”!

FX 的当前限制

尽管我很想就此结束这篇文章,但 FX 确实有一些自身的限制,主要归结为:

  1. 在将 Python 代码解释并转换为图的步骤中,可能存在一些 FX 尚不支持的代码。
  2. 动态控制流无法用静态图表示。

当出现这些问题时,最简单的做法是将底层代码打包到一个“叶子节点”(leaf node) 中。还记得图 3 中的示例图吗?从概念上讲,我们可以同意 submodule 本身应该被视为一个节点,而不是一组代表底层操作的节点。如果我们这样做,可以将图重新绘制为:

The individual operations within `submodule` may (left - within red box), may be consolidated into one node (right - node #2) if we consider the `submodule` as a 'leaf' node.
图 5:submodule 中的单个操作(左 - 红色框内)如果将 submodule 视为一个“叶子”节点,则可以合并为一个节点(右 - 节点 #2)。

如果在子模块中存在一些有问题代码,但我们不需要从其中提取任何中间转换结果,那么我们就会这样做。实际上,这可以通过向 create_feature_extractor 或 get_graph_node_names 提供一个关键字参数轻松实现。

model = CNN(3, 4, 10)
nodes, _ = get_graph_node_names(model, tracer_kwargs={'leaf_modules': [ConvBlock]})
print(nodes)

其输出将是:

['x', 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'global_pool', 'flatten', 'cls']

请注意,与之前相比,任何给定的 ConvBlock 的所有节点都合并为一个节点。

我们可以对函数执行类似的操作。例如,Python 的内置函数 len 需要被包装,并且结果应该被视为一个叶子节点。以下是如何使用 FX 核心功能来实现这一点:

torch.fx.wrap('len')

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       len(x)

model = MyModule()
feature_extractor = create_feature_extractor(model, return_nodes=['add'])

对于您定义的函数,您可以使用另一个关键字参数传递给 create_feature_extractor(次要细节:为什么您可能想这样做而不是那样

def myfunc(x):
   return len(x)

class MyModule(nn.Module):
   def forward(self, x):
       x += 1
       myfunc(x)

model = MyModule()
feature_extractor = create_feature_extractor(
   model, return_nodes=['add'], tracer_kwargs={'autowrap_functions': [myfunc]})

请注意,上述所有修复方法都不涉及修改源代码。

当然,有时您尝试访问的中间转换结果恰好位于导致问题的 forward 方法或函数内部。在这种情况下,我们不能简单地将该模块或函数视为叶子节点,因为那样就无法访问其中的中间转换结果。在这些情况下,需要对源代码进行一些重写。以下是一些示例(非详尽列表):

  • 当尝试追踪包含 assert 语句的代码时,FX 会引发错误。在这种情况下,您可能需要删除该断言或将其替换为 torch._assert(这不是一个公共函数 - 因此请将其视为权宜之计,谨慎使用。)
  • 不支持符号追踪张量切片的就地修改。您需要为切片创建一个新变量,应用操作,然后使用连接或堆叠来重构原始张量。
  • 在静态图中表示动态控制流在逻辑上是不可能的。看看您是否能将代码逻辑精炼为非动态的内容 - 请参阅 FX 文档获取提示。

总的来说,您可以查阅 FX 文档,获取更多关于 符号追踪的限制 和可能的解决方法 的详细信息。

结论

我们快速回顾了特征提取以及为什么要进行特征提取。尽管 PyTorch 中存在一些进行特征提取的现有方法,但它们都存在相当显著的缺点。我们了解了 TorchVision 的 FX 特征提取实用工具是如何工作的,以及与现有方法相比,它为何如此通用多能。尽管后者(指 FX)还有一些小问题需要解决,但我们理解其限制,并且可以根据我们的用例将其与其他方法的限制进行权衡。希望通过将这个新的实用工具添加到您的 PyTorch 工具包中,您现在能够处理您可能遇到的绝大多数特征提取需求。

编程愉快!