注意
点击此处下载完整的示例代码
性能分析你的 PyTorch Module¶
创建时间:2020 年 12 月 30 日 | 最后更新:2024 年 1 月 19 日 | 最后验证:2024 年 11 月 5 日
PyTorch 包含一个 profiler(性能分析器)API,可用于识别代码中各种 PyTorch 操作的时间和内存开销。Profiler 可以轻松地集成到你的代码中,并且结果可以打印为表格或以 JSON 跟踪文件的形式返回。
注意
Profiler 支持多线程模型。Profiler 在与操作相同的线程中运行,但它也会对可能在另一个线程中运行的子算子进行性能分析。并发运行的 profiler 将限定在其各自的线程范围内,以防止结果混杂。
注意
PyTorch 1.8 引入了新的 API,该 API 将在未来版本中取代旧的 profiler API。请访问此页面查看新 API。
前往这个代码示例,以便更快地了解 Profiler API 的用法。
import torch
import numpy as np
from torch import nn
import torch.autograd.profiler as profiler
使用 Profiler 进行性能调试¶
Profiler 可用于识别模型中的性能瓶颈。在本示例中,我们构建了一个执行两个子任务的自定义 module:
对输入进行线性变换,以及
使用变换结果获取 mask 张量上的索引。
我们使用 profiler.record_function("label")
将每个子任务的代码包装在单独的带标签的上下文管理器中。在 profiler 输出中,子任务中所有操作的聚合性能指标将显示在其对应的标签下。
请注意,使用 Profiler 会产生一定的开销,最好仅用于代码调查。如果你正在进行运行时基准测试,请记住将其移除。
class MyModule(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True):
super(MyModule, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias)
def forward(self, input, mask):
with profiler.record_function("LINEAR PASS"):
out = self.linear(input)
with profiler.record_function("MASK INDICES"):
threshold = out.sum(axis=1).mean().item()
hi_idx = np.argwhere(mask.cpu().numpy() > threshold)
hi_idx = torch.from_numpy(hi_idx).cuda()
return out, hi_idx
对前向传播进行性能分析¶
我们初始化随机的输入张量和 mask 张量,以及模型。
在运行 profiler 之前,我们先对 CUDA 进行预热,以确保精确的性能基准测试。我们将 module 的前向传播包装在 profiler.profile
上下文管理器中。with_stack=True
参数会在跟踪中附加操作的文件名和行号。
警告
with_stack=True
会产生额外的开销,更适合用于代码调查。如果你正在进行性能基准测试,请记住将其移除。
model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.double).cuda()
# warm-up
model(input, mask)
with profiler.profile(with_stack=True, profile_memory=True) as prof:
out, idx = model(input, mask)
打印 profiler 结果¶
最后,我们打印 profiler 结果。profiler.key_averages
按算子名称聚合结果,也可以选择按输入形状和/或堆栈跟踪事件进行聚合。按输入形状分组有助于识别模型使用了哪些张量形状。
在这里,我们使用 group_by_stack_n=5
,它按操作及其回溯(截断到最近的 5 个事件)聚合运行时,并按事件注册的顺序显示。表格也可以通过传递 sort_by
参数进行排序(请参考文档了解有效的排序键)。
注意
在 notebook 中运行 profiler 时,你在堆栈跟踪中可能会看到诸如 <ipython-input-18-193a910735e8>(13): forward
的条目,而不是文件名。这些对应于 <notebook-cell>(line number): calling-function
。
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))
"""
(Some columns are omitted)
------------- ------------ ------------ ------------ ---------------------------------
Name Self CPU % Self CPU Self CPU Mem Source Location
------------- ------------ ------------ ------------ ---------------------------------
MASK INDICES 87.88% 5.212s -953.67 Mb /mnt/xarfuse/.../torch/au
<ipython-input-...>(10): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
aten::copy_ 12.07% 715.848ms 0 b <ipython-input-...>(12): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
LINEAR PASS 0.01% 350.151us -20 b /mnt/xarfuse/.../torch/au
<ipython-input-...>(7): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
aten::addmm 0.00% 293.342us 0 b /mnt/xarfuse/.../torch/nn
/mnt/xarfuse/.../torch/nn
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(8): forward
/mnt/xarfuse/.../torch/nn
aten::mean 0.00% 235.095us 0 b <ipython-input-...>(11): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
----------------------------- ------------ ---------- ----------------------------------
Self CPU time total: 5.931s
"""
改进内存性能¶
请注意,内存和时间方面开销最大的操作位于 forward (10)
,它代表 MASK INDICES 中的操作。我们先尝试解决内存消耗问题。可以看到,第 12 行的 .to()
操作消耗了 953.67 Mb。此操作将 mask
复制到 CPU。mask
使用 torch.double
数据类型初始化。我们能否通过将其转换为 torch.float
来减少内存占用?
model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.float).cuda()
# warm-up
model(input, mask)
with profiler.profile(with_stack=True, profile_memory=True) as prof:
out, idx = model(input, mask)
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))
"""
(Some columns are omitted)
----------------- ------------ ------------ ------------ --------------------------------
Name Self CPU % Self CPU Self CPU Mem Source Location
----------------- ------------ ------------ ------------ --------------------------------
MASK INDICES 93.61% 5.006s -476.84 Mb /mnt/xarfuse/.../torch/au
<ipython-input-...>(10): forward
/mnt/xarfuse/ /torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
aten::copy_ 6.34% 338.759ms 0 b <ipython-input-...>(12): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
aten::as_strided 0.01% 281.808us 0 b <ipython-input-...>(11): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
aten::addmm 0.01% 275.721us 0 b /mnt/xarfuse/.../torch/nn
/mnt/xarfuse/.../torch/nn
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(8): forward
/mnt/xarfuse/.../torch/nn
aten::_local 0.01% 268.650us 0 b <ipython-input-...>(11): forward
_scalar_dense /mnt/xarfuse/.../torch/nn
<ipython-input-...>(9): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
----------------- ------------ ------------ ------------ --------------------------------
Self CPU time total: 5.347s
"""
此操作的 CPU 内存占用减少了一半。
改进时间性能¶
虽然消耗的时间也减少了一些,但仍然太高。事实证明,将矩阵从 CUDA 复制到 CPU 非常耗时!forward (12)
中的 aten::copy_
算子将 mask
复制到 CPU,以便可以使用 NumPy 的 argwhere
函数。forward(13)
中的 aten::copy_
将数组作为张量复制回 CUDA。如果我们在此处改用 torch
函数 nonzero()
,就可以消除这两个操作。
class MyModule(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True):
super(MyModule, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias)
def forward(self, input, mask):
with profiler.record_function("LINEAR PASS"):
out = self.linear(input)
with profiler.record_function("MASK INDICES"):
threshold = out.sum(axis=1).mean()
hi_idx = (mask > threshold).nonzero(as_tuple=True)
return out, hi_idx
model = MyModule(500, 10).cuda()
input = torch.rand(128, 500).cuda()
mask = torch.rand((500, 500, 500), dtype=torch.float).cuda()
# warm-up
model(input, mask)
with profiler.profile(with_stack=True, profile_memory=True) as prof:
out, idx = model(input, mask)
print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))
"""
(Some columns are omitted)
-------------- ------------ ------------ ------------ ---------------------------------
Name Self CPU % Self CPU Self CPU Mem Source Location
-------------- ------------ ------------ ------------ ---------------------------------
aten::gt 57.17% 129.089ms 0 b <ipython-input-...>(12): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(25): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
aten::nonzero 37.38% 84.402ms 0 b <ipython-input-...>(12): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(25): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
INDEX SCORE 3.32% 7.491ms -119.21 Mb /mnt/xarfuse/.../torch/au
<ipython-input-...>(10): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(25): <module>
/mnt/xarfuse/.../IPython/
aten::as_strided 0.20% 441.587us 0 b <ipython-input-...>(12): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(25): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
aten::nonzero
_numpy 0.18% 395.602us 0 b <ipython-input-...>(12): forward
/mnt/xarfuse/.../torch/nn
<ipython-input-...>(25): <module>
/mnt/xarfuse/.../IPython/
/mnt/xarfuse/.../IPython/
-------------- ------------ ------------ ------------ ---------------------------------
Self CPU time total: 225.801ms
"""