(beta) Building a Simple CPU Performance Profiler with FX
Created On: Mar 04, 2021 | Last Updated: Jan 16, 2024 | Last Verified: Not Verified
Author: James Reed
In this tutorial, we are going to use FX to do the following:
1. Capture PyTorch Python code in a way that lets us inspect and gather statistics about the structure and execution of the code.
2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.
For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.
import torch
import torch.fx
import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model are taking the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
The most common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
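For illustration, a minimal sketch of that manual approach might look like the following (the choice of regions here is hypothetical, purely for demonstration):

import time

t0 = time.time()
x = rn18.conv1(input)   # region 1: time the stem convolution...
t1 = time.time()
x = rn18.bn1(x)         # region 2: ...and the first batch norm separately
t2 = time.time()
print(f'conv1: {t1 - t0:.6f}s, bn1: {t2 - t1:.6f}s')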
While this technique is certainly applicable to PyTorch code, it would be nicer if we didn't have to copy over the model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.
First, let's get some imports out of the way (we will be using all of these later in the code).
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
Note
tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source.
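If you have not used tabulate before, here is a quick sketch of its basic usage (the rows here are made-up data, not real measurements):

import tabulate
rows = [['conv1', 0.005], ['bn1', 0.001]]  # hypothetical op timings
print(tabulate.tabulate(rows, headers=['Op', 'Runtime (s)']))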
Capturing the Model with Symbolic Tracing
Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
%x : torch.Tensor [num_users=1] = placeholder[target=x]
%conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
%bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
%relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
%maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
%layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
%layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
%layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
%layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
%layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
%add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
%layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
%layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
%layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
%layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
%layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
%layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
%layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
%layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
%layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
%layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
%layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
%layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
%layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
%layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
%layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
%layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
%layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
%layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
%layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
%layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
%layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
%layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
%layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
%layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
%layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
%layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
%layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
%layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
%layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
%layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
%add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
%layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
%layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
%layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
%layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
%layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
%layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
%add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
%layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
%layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
%layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
%layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
%layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
%layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
%add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
%layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
%avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
%flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
%fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
return fc
This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between these call-sites. More information about the Graph representation and the rest of the FX APIs can be found in the FX documentation: https://pytorch.ac.cn/docs/master/fx.html.
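Since the Graph is an ordinary Python data structure, we can also walk it directly. For example, a small sketch that tallies how many nodes of each opcode appear in the traced graph:

from collections import Counter

# Each Node carries an opcode (``node.op``) and a ``target``
op_counts = Counter(node.op for node in traced_rn18.graph.nodes)
print(op_counts)  # counts of placeholder, call_module, call_function, and output nodes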
Creating a Profiling Interpreter
Next, we are going to create a class that inherits from torch.fx.Interpreter. Although the GraphModule that symbolic_trace produces compiles Python code that runs when you call the GraphModule, an alternative way to run a GraphModule is to execute each Node in the Graph one by one. That is the functionality Interpreter provides: it interprets the graph node-by-node.
By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each of its parts took during those runs.
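As a quick sanity check, a plain Interpreter (with nothing overridden) already reproduces the model's behavior when run node-by-node; a minimal sketch, reusing the input tensor from earlier:

with torch.no_grad():
    direct_out = traced_rn18(input)
    interp_out = Interpreter(traced_rn18).run(input)
print(torch.allclose(direct_out, interp_out))  # expected: True, since rn18 is in eval mode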
Let's define our ProfilingInterpreter class:
class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
Note
We use Python's time.time function to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance, and it only gives us a first-order approximation. We use this simple technique solely for demonstration purposes in this tutorial.
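As an aside, if you want a slightly better clock than time.time, one option is time.perf_counter, which is designed for measuring short durations; a sketch of timing one forward pass with it (the profiler above keeps using time.time, matching the original code):

from time import perf_counter

t0 = perf_counter()
with torch.no_grad():
    rn18(input)
print(f'forward pass took {perf_counter() - t0:.6f}s')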
Investigating the Performance of ResNet18
We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type Op Average runtime (s) Pct total runtime
------------- --------------------- --------------------- -------------------
call_module maxpool 0.00859714 8.02576
call_module conv1 0.00653529 6.10095
call_module layer4_0_conv2 0.00636888 5.94559
call_module layer4_1_conv1 0.00583315 5.44547
call_module layer4_1_conv2 0.00581217 5.42588
call_module layer1_0_conv1 0.00575733 5.37469
call_module layer1_1_conv2 0.00560737 5.23469
call_module layer2_1_conv1 0.00524116 4.89282
call_module layer1_1_conv1 0.00520563 4.85966
call_module layer3_1_conv2 0.00516272 4.81959
call_module layer3_1_conv1 0.00511885 4.77864
call_module layer3_0_conv2 0.0050416 4.70653
call_module layer2_1_conv2 0.00477362 4.45635
call_module layer1_0_conv2 0.00472641 4.41229
call_module layer2_0_conv2 0.00456476 4.26138
call_module layer4_0_conv1 0.00383401 3.57919
call_module layer3_0_conv1 0.00297451 2.77682
call_module layer2_0_conv1 0.00259256 2.42026
call_module bn1 0.00228763 2.13559
call_module layer2_0_downsample_0 0.00136423 1.27356
call_function add_1 0.000645161 0.602282
call_module layer3_0_downsample_0 0.000638962 0.596495
call_function add 0.000619888 0.57869
call_module layer4_0_downsample_0 0.000584602 0.545749
call_module relu 0.000487328 0.454939
call_function add_3 0.000329733 0.307818
call_module layer1_1_bn2 0.000267506 0.249727
call_module layer1_0_bn1 0.000254631 0.237708
call_module layer1_1_bn1 0.000223398 0.208551
call_module layer1_0_bn2 0.000217915 0.203432
call_module fc 0.000207663 0.193861
call_module layer2_0_downsample_1 0.00019908 0.185848
call_module layer2_1_bn1 0.000183105 0.170936
call_module layer2_0_bn2 0.000172615 0.161143
call_module layer2_1_bn2 0.000161886 0.151127
call_module avgpool 0.000159979 0.149346
call_module layer2_0_bn1 0.000159264 0.148679
call_module layer3_1_bn2 0.000156403 0.146008
call_module layer3_1_bn1 0.000155449 0.145118
call_module layer4_1_bn2 0.000144243 0.134657
call_module layer3_0_downsample_1 0.000143051 0.133544
call_module layer4_0_bn2 0.000140667 0.131318
call_module layer4_1_bn1 0.000140429 0.131095
call_module layer3_0_bn2 0.000138998 0.12976
call_module layer1_0_relu 0.00013876 0.129537
call_module layer1_1_relu_1 0.000132561 0.123751
call_module layer4_0_bn1 0.000131369 0.122638
call_module layer3_0_bn1 0.000128984 0.120412
call_module layer1_0_relu_1 0.000127077 0.118631
call_module layer1_1_relu 0.000126839 0.118409
call_module layer4_0_downsample_1 0.000124693 0.116406
call_function add_2 0.000115156 0.107503
call_function add_7 0.000108719 0.101493
call_function add_6 9.94205e-05 0.0928129
call_module layer2_0_relu_1 9.77516e-05 0.0912549
call_module layer2_1_relu 9.75132e-05 0.0910323
call_module layer2_0_relu 9.53674e-05 0.0890292
call_module layer2_1_relu_1 9.48906e-05 0.088584
call_module layer4_1_relu 9.27448e-05 0.0865809
call_module layer4_0_relu 9.2268e-05 0.0861357
call_function add_5 8.79765e-05 0.0821294
call_module layer3_1_relu 8.08239e-05 0.0754522
call_module layer3_0_relu 8.05855e-05 0.0752296
call_module layer4_1_relu_1 7.89165e-05 0.0736716
call_module layer4_0_relu_1 7.82013e-05 0.0730039
call_function add_4 7.79629e-05 0.0727813
call_module layer3_0_relu_1 7.77245e-05 0.0725588
call_module layer3_1_relu_1 7.72476e-05 0.0721136
call_function flatten 4.24385e-05 0.039618
placeholder x 2.6226e-05 0.024483
output output 1.90735e-05 0.0178058
There are two things we should call out here:
1. MaxPool2d takes up the most time. This is a known issue: https://github.com/pytorch/pytorch/issues/51393
2. BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize it in the Conv-BN Fusion with FX tutorial.
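Because ProfilingInterpreter accumulates runtimes across calls, we can also average over several runs to get more stable numbers before drawing conclusions like the ones above; a short sketch:

# Each call to run() appends new samples; summary() then averages them
with torch.no_grad():
    for _ in range(10):
        interp.run(input)
print(interp.summary(should_sort=True))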
Conclusion
As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format, and use that representation for analysis, such as the performance analysis we have done here. FX opens up an exciting world of possibilities for working with PyTorch programs.
Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to share any feedback you may have.
Total running time of the script: ( 0 minutes 0.486 seconds)