
(beta) Building a Simple CPU Performance Profiler with FX

Author: James Reed

In this tutorial, we are going to use FX to do the following:

  1. Capture PyTorch Python code in a way that lets us inspect it and gather statistics about the structure and execution of the code

  2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.

In this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model take the longest?

input = torch.randn(5, 3, 224, 224)
output = rn18(input)

A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
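
For instance, a manually instrumented version of part of the forward pass might look like this (a hypothetical sketch, not from the tutorial; the function name is made up for illustration):

import time

def forward_prefix_with_timestamps(model, x):
    # Time the first two operations of the forward pass by hand
    t0 = time.time()
    x = model.conv1(x)
    t1 = time.time()
    x = model.bn1(x)
    t2 = time.time()
    print(f'conv1: {t1 - t0:.6f} s, bn1: {t2 - t1:.6f} s')
    return x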

That technique would certainly work for PyTorch code; however, it would be nicer if we didn't have to copy over model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source code.

First, let's get some imports out of the way (we will be using all of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

Note

tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source.
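
For example, with pip (one common way to install it; use whichever package manager you prefer):

pip install tabulate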

Capturing the Model with Symbolic Tracing

Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    %layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    %layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    %layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
    %layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
    %layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
    %layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
    %layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
    %layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
    %layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
    %layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
    %layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
    %layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
    %layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
    %layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
    %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    return fc

This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between these call-sites. More information about the Graph representation and the rest of the FX APIs can be found in the FX documentation: https://pytorch.ac.cn/docs/master/fx.html
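
As a quick illustration (a small sketch, not part of the original tutorial), we can walk the traced graph and inspect each node's opcode, target, and inputs directly:

for node in traced_rn18.graph.nodes:
    print(node.op, node.target, node.args)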

Creating a Profiling Interpreter

Next, we are going to create a class that inherits from torch.fx.Interpreter. Although the GraphModule that symbolic_trace produces compiles Python code that is run when you call the GraphModule, an alternative way to run a GraphModule is to execute each Node in the Graph one by one. That is the functionality Interpreter provides: it interprets the graph node-by-node.
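
Both execution paths produce the same result; here is a quick sketch (not part of the original tutorial) contrasting them:

out_compiled = traced_rn18(input)                      # compiled forward
out_interpreted = Interpreter(traced_rn18).run(input)  # node-by-node execution
torch.testing.assert_close(out_compiled, out_interpreted)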

By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each of its parts took during those runs.

Let's define our ProfilingInterpreter class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

Note

We use Python's time.time function to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance and will only give us a first-order approximation. We use this simple technique for demonstration purposes in this tutorial.
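
If you want slightly more reliable wall-clock numbers, Python's time.perf_counter offers a higher-resolution monotonic clock (a small sketch; the tutorial sticks with time.time for simplicity):

t_start = time.perf_counter()
rn18(input)
t_end = time.perf_counter()
print(f'forward pass took {t_end - t_start:.6f} s')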

Investigating the Performance of ResNet18

We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type        Op                       Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                          0.00844693            11.135
call_module    conv1                            0.00659585             8.69485
call_module    layer1_0_conv1                   0.00542212             7.14759
call_module    layer3_0_conv2                   0.00409031             5.39197
call_module    layer4_0_conv2                   0.00383186             5.05128
call_module    layer1_0_conv2                   0.00360656             4.75427
call_module    layer1_1_conv1                   0.0034225              4.51164
call_module    layer4_1_conv2                   0.00321341             4.23601
call_module    layer4_1_conv1                   0.00310612             4.09458
call_module    layer2_1_conv1                   0.00305367             4.02543
call_module    layer3_1_conv1                   0.00288701             3.80574
call_module    layer3_1_conv2                   0.00280714             3.70046
call_module    layer1_1_conv2                   0.00274491             3.61843
call_module    layer2_1_conv2                   0.00266814             3.51722
call_module    layer2_0_conv2                   0.00243831             3.21425
call_module    bn1                              0.00231838             3.05616
call_module    layer4_0_conv1                   0.00226879             2.99079
call_module    layer2_0_conv1                   0.00185394             2.44392
call_module    layer3_0_conv1                   0.00175261             2.31035
call_module    layer2_0_downsample_0            0.00104856             1.38225
call_module    layer4_0_downsample_0            0.000541687            0.714068
call_function  add_1                            0.000507355            0.66881
call_function  add                              0.000486135            0.640838
call_module    layer3_0_downsample_0            0.000485182            0.639581
call_module    relu                             0.000395536            0.521408
call_function  add_3                            0.000303268            0.399777
call_module    layer1_0_bn1                     0.000268221            0.353577
call_module    layer1_1_bn1                     0.000190735            0.251432
call_module    layer1_0_bn2                     0.000177145            0.233518
call_module    fc                               0.000173569            0.228803
call_module    layer1_1_bn2                     0.00016284             0.21466
call_module    layer1_0_relu                    0.000154972            0.204289
call_module    avgpool                          0.000151634            0.199889
call_module    layer2_0_bn1                     0.00014019             0.184803
call_module    layer2_0_downsample_1            0.00013876             0.182917
call_module    layer2_1_bn1                     0.000138283            0.182288
call_module    layer4_1_bn2                     0.000134706            0.177574
call_module    layer2_1_bn2                     0.000132084            0.174117
call_module    layer3_0_bn2                     0.000131607            0.173488
call_module    layer3_1_bn2                     0.000127792            0.16846
call_module    layer3_1_bn1                     0.000127554            0.168145
call_module    layer2_0_bn2                     0.000127077            0.167517
call_module    layer4_0_downsample_1            0.000122786            0.16186
call_module    layer4_0_bn2                     0.000121355            0.159974
call_module    layer3_0_bn1                     0.000119209            0.157145
call_module    layer4_0_bn1                     0.000116825            0.154002
call_module    layer3_0_downsample_1            0.000116348            0.153374
call_module    layer4_1_bn1                     0.000115871            0.152745
call_module    layer1_0_relu_1                  0.000112534            0.148345
call_module    layer1_1_relu                    0.000110626            0.145831
call_module    layer1_1_relu_1                  0.00010705             0.141116
call_function  add_2                            8.67844e-05            0.114402
call_module    layer4_1_relu                    8.46386e-05            0.111573
call_module    layer2_0_relu                    8.44002e-05            0.111259
call_module    layer4_0_relu                    8.44002e-05            0.111259
call_module    layer2_1_relu                    8.10623e-05            0.106859
call_module    layer2_1_relu_1                  8.08239e-05            0.106544
call_module    layer2_0_relu_1                  7.77245e-05            0.102459
call_module    layer3_0_relu                    7.43866e-05            0.0980586
call_module    layer4_0_relu_1                  7.41482e-05            0.0977443
call_function  add_7                            7.39098e-05            0.09743
call_module    layer3_1_relu                    7.36713e-05            0.0971158
call_function  add_6                            7.27177e-05            0.0958586
call_function  add_5                            7.1764e-05             0.0946014
call_module    layer3_0_relu_1                  7.12872e-05            0.0939729
call_module    layer4_1_relu_1                  7.12872e-05            0.0939729
call_module    layer3_1_relu_1                  7.03335e-05            0.0927157
call_function  add_4                            6.65188e-05            0.087687
call_function  flatten                          3.93391e-05            0.0518579
placeholder    x                                2.52724e-05            0.0333148
output         output                           1.69277e-05            0.0223146

There are two things we should call out here:

  1. MaxPool2d takes up the biggest amount of time. This is a known issue.

  2. The BatchNorm2d layers also take up significant time collectively. A natural follow-up is to optimize them, for example by fusing batch norm into the preceding convolutions (Conv-BN fusion).
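
Because ProfilingInterpreter accumulates a list of runtimes for every node, we can also call run() several times and let summary() average across all of the runs (a usage sketch):

for _ in range(10):
    interp.run(input)
print(interp.summary(True))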

Conclusion

As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use it for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.

Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.

Total running time of the script: (0 minutes 0.465 seconds)
