(beta) Building a Simple CPU Performance Profiler with FX

Created On: Mar 04, 2021 | Last Updated: Jan 16, 2024 | Last Verified: Not Verified

Author: James Reed

In this tutorial, we are going to use FX to do the following:

  1. Capture PyTorch Python code in a way that lets us inspect it and gather statistics about the structure and execution of the code

  2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.

For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model take the longest?

input = torch.randn(5, 3, 224, 224)
output = rn18(input)

A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long each region takes.
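For example, hand-instrumenting the first couple of layers might look like the following sketch (conv1 and bn1 are submodules of the ResNet18 printed above):

import time

t0 = time.time()
x = rn18.conv1(input)   # time the first convolution
t1 = time.time()
x = rn18.bn1(x)         # time the first batch norm
t2 = time.time()
print(f"conv1 took {t1 - t0:.6f}s, bn1 took {t2 - t1:.6f}s")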

That technique certainly works for PyTorch code, but it would be nicer if we didn't have to copy the model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.

First, let's get some imports out of the way (we will be using all of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

Note

tabulate is an external library that is not a dependency of PyTorch. We will use it to more easily visualize performance data. Please make sure you have installed it from your favorite Python package source.
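For example, with pip:

pip install tabulate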

Capturing the Model with Symbolic Tracing

Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    %layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    %layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    %layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
    %layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
    %layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
    %layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
    %layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
    %layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
    %layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
    %layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
    %layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
    %layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
    %layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
    %layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
    %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    return fc

This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation: https://pytorch.ac.cn/docs/master/fx.html
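As a quick sketch of what we can do with this data structure, we can walk the Graph and print each Node's opcode, target, and input arguments:

for node in traced_rn18.graph.nodes:
    # node.op is the opcode (e.g. call_module), node.target is what is
    # called, and node.args are the values flowing into this call-site
    print(node.op, node.target, node.args)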

Creating a Profiling Interpreter

Next, we are going to create a class that inherits from torch.fx.Interpreter. Though the GraphModule that symbolic_trace produces compiles Python code that is run when you call a GraphModule, an alternative way to run a GraphModule is by executing each Node in the Graph one by one. That is the functionality Interpreter provides: it interprets the graph node-by-node.
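As a quick sanity check (a minimal sketch), we can confirm that interpreting the graph node-by-node produces the same result as calling the compiled GraphModule directly:

x = torch.randn(1, 3, 224, 224)
# Node-by-node interpretation should match the compiled forward pass
torch.testing.assert_close(Interpreter(traced_rn18).run(x), traced_rn18(x))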

By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model 1 or more times, then get statistics about how long the model and each part of it took during those runs.

Let's define our ProfilingInterpreter class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

Note

We use Python's time.time function to pull wall clock timestamps and compare them. This is not the most accurate way to measure performance, and will only give us a first-order approximation. We use this simple technique only for demonstration purposes in this tutorial.
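If you want a somewhat more reliable wall-clock measurement, Python's time.perf_counter can serve as a drop-in replacement for time.time in the code above (a sketch):

t_start = time.perf_counter()   # monotonic, higher-resolution clock
rn18(input)
t_end = time.perf_counter()
print(f"One forward pass took {t_end - t_start:.6f}s")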

Investigating the Performance of ResNet18

We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type        Op                       Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                          0.00606203            10.1653
call_module    conv1                            0.00542283             9.09342
call_module    layer4_0_conv2                   0.00363326             6.09253
call_module    layer1_0_conv1                   0.00321054             5.38369
call_module    layer1_0_conv2                   0.00315976             5.29853
call_module    layer4_1_conv1                   0.00309634             5.19218
call_module    layer4_1_conv2                   0.00293541             4.92232
call_module    layer1_1_conv2                   0.00277257             4.64926
call_module    layer1_1_conv1                   0.00243998             4.09154
call_module    layer3_1_conv2                   0.00232077             3.89164
call_module    layer2_1_conv1                   0.0022316              3.74211
call_module    layer2_1_conv2                   0.00211406             3.54501
call_module    layer3_0_conv2                   0.00210214             3.52502
call_module    layer3_1_conv1                   0.00208902             3.50303
call_module    layer2_0_conv2                   0.0020225              3.39149
call_module    layer4_0_conv1                   0.00196433             3.29394
call_module    bn1                              0.00139451             2.33842
call_module    layer2_0_conv1                   0.00133061             2.23128
call_module    layer3_0_conv1                   0.00124478             2.08735
call_module    layer2_0_downsample_0            0.00108624             1.82148
call_module    layer4_0_downsample_0            0.000471592            0.790801
call_module    layer3_0_downsample_0            0.000450134            0.75482
call_function  add                              0.000432968            0.726034
call_function  add_1                            0.000419855            0.704045
call_module    relu                             0.000311613            0.522537
call_module    layer1_0_bn1                     0.000288248            0.483356
call_module    layer1_0_bn2                     0.000271082            0.454571
call_module    layer1_1_bn2                     0.000260115            0.43618
call_module    fc                               0.000248432            0.41659
call_function  add_3                            0.000231981            0.389004
call_module    layer2_1_bn2                     0.000171661            0.287855
call_module    layer2_1_bn1                     0.000165224            0.27706
call_module    layer2_0_downsample_1            0.000153065            0.256671
call_module    layer1_1_bn1                     0.000150442            0.252273
call_module    avgpool                          0.000132084            0.221488
call_module    layer3_1_bn2                     0.000120401            0.201898
call_module    layer3_1_bn1                     0.000115395            0.193502
call_module    layer4_1_bn2                     0.000115395            0.193502
call_module    layer4_0_bn2                     0.000113249            0.189904
call_module    layer1_0_relu                    9.799e-05              0.164317
call_module    layer3_0_bn2                     9.60827e-05            0.161119
call_module    layer1_0_relu_1                  9.58443e-05            0.160719
call_module    layer2_0_bn1                     9.10759e-05            0.152723
call_module    layer2_0_bn2                     9.08375e-05            0.152323
call_module    layer4_1_bn1                     8.58307e-05            0.143927
call_function  add_2                            8.27312e-05            0.13873
call_module    layer1_1_relu_1                  8.2016e-05             0.137531
call_function  add_5                            7.93934e-05            0.133133
call_function  add_7                            7.67708e-05            0.128735
call_module    layer4_0_downsample_1            7.53403e-05            0.126336
call_module    layer4_0_bn1                     7.43866e-05            0.124737
call_module    layer1_1_relu                    7.29561e-05            0.122338
call_module    layer3_0_downsample_1            7.03335e-05            0.117941
call_module    layer3_0_bn1                     6.86646e-05            0.115142
call_function  add_6                            6.67572e-05            0.111944
call_module    layer4_0_relu                    6.36578e-05            0.106746
call_function  add_4                            5.74589e-05            0.0963514
call_module    layer4_1_relu                    5.34058e-05            0.0895549
call_module    layer2_0_relu_1                  5.22137e-05            0.0875559
call_module    layer4_0_relu_1                  5.05447e-05            0.0847573
call_module    layer2_1_relu                    4.69685e-05            0.0787603
call_module    layer2_1_relu_1                  4.69685e-05            0.0787603
call_module    layer2_0_relu                    4.673e-05              0.0783605
call_module    layer4_1_relu_1                  4.33922e-05            0.0727633
call_module    layer3_1_relu                    3.83854e-05            0.0643676
call_module    layer3_1_relu_1                  3.74317e-05            0.0627684
call_module    layer3_0_relu                    3.69549e-05            0.0619688
call_module    layer3_0_relu_1                  3.69549e-05            0.0619688
call_function  flatten                          2.88486e-05            0.0483756
placeholder    x                                1.57356e-05            0.0263867
output         output                           9.77516e-06            0.0163917

There are two things we should call out here:

  * MaxPool2d takes up the most time in this run, noticeably more than any single convolution. This is a known issue in PyTorch.

  * The BatchNorm2d calls also take up significant time; fusing Conv and BatchNorm (covered in the Conv-BN Fusion with FX tutorial) is one way to optimize this.
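Because ProfilingInterpreter accumulates runtimes across calls (both per-node and for the model as a whole), we can run the model several more times to get more stable averages (a sketch):

for _ in range(10):
    interp.run(input)           # each run appends to the stored runtimes
print(interp.summary(True))     # means are now taken over all runs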

Conclusion

As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.

Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.

Total running time of the script: (0 minutes 0.292 seconds)
