注意
点击此处下载完整示例代码
定时器快速入门¶
创建日期:2021年4月1日 | 最后更新:2024年1月19日 | 最后验证:未验证
在本教程中,我们将介绍 torch.utils.benchmark.Timer 的主要 API。PyTorch Timer 基于 timeit.Timer API,并进行了一些 PyTorch 特定的修改。本教程不需要熟悉内置的 Timer 类,但我们假设读者熟悉性能工作的基本原理。
有关更全面的性能调优教程,请参阅 PyTorch Benchmark。
- 目录
1. 定义 Timer¶
Timer 用作任务定义。
from torch.utils.benchmark import Timer
timer = Timer(
# The computation which will be run in a loop and timed.
stmt="x * y",
# `setup` will be run before calling the measurement loop, and is used to
# populate any state which is needed by `stmt`
setup="""
x = torch.ones((128,))
y = torch.ones((128,))
""",
# Alternatively, ``globals`` can be used to pass variables from the outer scope.
#
# globals={
# "x": torch.ones((128,)),
# "y": torch.ones((128,)),
# },
# Control the number of threads that PyTorch uses. (Default: 1)
num_threads=1,
)
2. 实际时间(Wall time):Timer.blocked_autorange(...)
¶
此方法将处理诸如选择合适的重复次数、固定线程数以及提供方便的结果表示形式等细节。
# Measurement objects store the results of multiple repeats, and provide
# various utility features.
from torch.utils.benchmark import Measurement
m: Measurement = timer.blocked_autorange(min_run_time=1)
print(m)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f1929a38ed0>
x * y
setup:
x = torch.ones((128,))
y = torch.ones((128,))
Median: 2.34 us
IQR: 0.07 us (2.31 to 2.38)
424 measurements, 1000 runs per measurement, 1 thread
3. C++ 代码片段¶
from torch.utils.benchmark import Language
cpp_timer = Timer(
"x * y;",
"""
auto x = torch::ones({128});
auto y = torch::ones({128});
""",
language=Language.CPP,
)
print(cpp_timer.blocked_autorange(min_run_time=1))
<torch.utils.benchmark.utils.common.Measurement object at 0x7f192b019ed0>
x * y;
setup:
auto x = torch::ones({128});
auto y = torch::ones({128});
Median: 1.21 us
IQR: 0.03 us (1.20 to 1.23)
83 measurements, 10000 runs per measurement, 1 thread
毫不意外,C++ 代码片段既更快,变异性也更低。
4. 指令计数:Timer.collect_callgrind(...)
¶
对于深入研究,Timer.collect_callgrind
封装了 Callgrind 以收集指令计数。这些指令计数非常有用,因为它们提供了对代码片段如何运行的细粒度且确定性(在 Python 的情况下噪声非常低)的洞察。
from torch.utils.benchmark import CallgrindStats, FunctionCounts
stats: CallgrindStats = cpp_timer.collect_callgrind()
print(stats)
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f1929a35850>
x * y;
setup:
auto x = torch::ones({128});
auto y = torch::ones({128});
All Noisy symbols removed
Instructions: 563600 563600
Baseline: 0 0
100 runs per measurement, 1 thread
5. 指令计数:深入探究¶
CallgrindStats
的字符串表示形式与 Measurement 类似。Noisy symbols 是一个 Python 概念(移除 CPython 解释器中已知会产生噪声的调用)。
然而,为了进行更详细的分析,我们会希望查看特定的调用。CallgrindStats.stats()
返回一个 FunctionCounts
对象,以便更轻松地实现这一点。概念上,FunctionCounts
可以被视为一个包含一些实用方法的对偶元组,其中每一对是 (指令数,文件路径和函数名)。
- 关于路径的说明
通常人们不关心绝对路径。例如,乘法调用的完整路径和函数名是这样的:
/the/prefix/to/your/pytorch/install/dir/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::mul(at::Tensor const&) const [/the/path/to/your/conda/install/miniconda3/envs/ab_ref/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so]
when in reality, all of the information that we're interested in can be
represented in:
build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::mul(at::Tensor const&) const
``CallgrindStats.as_standardized()`` makes a best effort to strip low signal
portions of the file path, as well as the shared object and is generally
recommended.
inclusive_stats = stats.as_standardized().stats(inclusive=False)
print(inclusive_stats[:10])
torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f192a6dfd90>
47264 ???:_int_free
25963 ???:_int_malloc
19900 build/../aten/src/ATen/TensorIter ... (at::TensorIteratorConfig const&)
18000 ???:__tls_get_addr
13500 ???:malloc
11300 build/../c10/util/SmallVector.h:a ... (at::TensorIteratorConfig const&)
10345 ???:_int_memalign
10000 build/../aten/src/ATen/TensorIter ... (at::TensorIteratorConfig const&)
9200 ???:free
8000 build/../c10/util/SmallVector.h:a ... IteratorBase::get_strides() const
Total: 173472
这仍然需要消化很多信息。让我们使用 FunctionCounts.transform 方法来修剪一些函数路径,并丢弃调用的函数。这样做时,任何冲突(例如 foo.h:a() 和 foo.h:b() 都将映射到 foo.h)的计数将被累加。
import os
import re
def group_by_file(fn_name: str):
if fn_name.startswith("???"):
fn_dir, fn_file = fn_name.split(":")[:2]
else:
fn_dir, fn_file = os.path.split(fn_name.split(":")[0])
fn_dir = re.sub("^.*build/../", "", fn_dir)
fn_dir = re.sub("^.*torch/", "torch/", fn_dir)
return f"{fn_dir:<15} {fn_file}"
print(inclusive_stats.transform(group_by_file)[:10])
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f192995d750>
118200 aten/src/ATen TensorIterator.cpp
65000 c10/util SmallVector.h
47264 ??? _int_free
25963 ??? _int_malloc
20900 c10/util intrusive_ptr.h
18000 ??? __tls_get_addr
15900 c10/core TensorImpl.h
15100 c10/core CPUAllocator.cpp
13500 ??? malloc
12500 c10/core TensorImpl.cpp
Total: 352327
6. 使用 Callgrind
进行 A/B 测试¶
指令计数最有用的特性之一是它们允许对计算进行细粒度的比较,这在分析性能时至关重要。
为了看到这一点,让我们将两个大小为 128 的张量相乘与 {128} x {1} 的乘法进行比较,后者将广播第二个张量:
result = {a0 * b0, a1 * b0, …, a127 * b0}
broadcasting_stats = Timer(
"x * y;",
"""
auto x = torch::ones({128});
auto y = torch::ones({1});
""",
language=Language.CPP,
).collect_callgrind().as_standardized().stats(inclusive=False)
我们经常希望对两个不同的环境进行 A/B 测试。(例如测试一个 PR,或实验编译标志。)这非常简单,因为 CallgrindStats
、FunctionCounts
和 Measurement 都是可序列化的。只需保存每个环境的测量结果,然后在单个进程中加载它们进行分析。
import pickle
# Let's round trip `broadcasting_stats` just to show that we can.
broadcasting_stats = pickle.loads(pickle.dumps(broadcasting_stats))
# And now to diff the two tasks:
delta = broadcasting_stats - inclusive_stats
def extract_fn_name(fn: str):
"""Trim everything except the function name."""
fn = ":".join(fn.split(":")[1:])
return re.sub(r"\(.+\)", "(...)", fn)
# We use `.transform` to make the diff readable:
print(delta.transform(extract_fn_name))
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f192995d750>
17600 at::TensorIteratorBase::compute_strides(...)
12700 at::TensorIteratorBase::allocate_or_resize_outputs()
10200 c10::SmallVectorImpl<long>::operator=(...)
7400 at::infer_size(...)
6200 at::TensorIteratorBase::invert_perm(...) const
6064 _int_free
5100 at::TensorIteratorBase::reorder_dimensions()
4300 malloc
4300 at::TensorIteratorBase::compatible_stride(...) const
...
-28 _int_memalign
-100 c10::impl::check_tensor_options_and_extract_memory_format(...)
-300 __memcmp_avx2_movbe
-400 at::detail::empty_cpu(...)
-1100 at::TensorIteratorBase::numel() const
-1300 void at::native::(...)
-2400 c10::TensorImpl::is_contiguous(...) const
-6100 at::TensorIteratorBase::compute_fast_setup_type(...)
-22600 at::TensorIteratorBase::fast_set_up(...)
Total: 58091
因此,广播版本每次调用额外需要 580 条指令(请记住我们每个样本收集了 100 次运行),大约增加了 10%。其中有很多 TensorIterator
调用,所以让我们深入研究这些调用。FunctionCounts.filter
使这变得容易。
print(delta.transform(extract_fn_name).filter(lambda fn: "TensorIterator" in fn))
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f19299544d0>
17600 at::TensorIteratorBase::compute_strides(...)
12700 at::TensorIteratorBase::allocate_or_resize_outputs()
6200 at::TensorIteratorBase::invert_perm(...) const
5100 at::TensorIteratorBase::reorder_dimensions()
4300 at::TensorIteratorBase::compatible_stride(...) const
4000 at::TensorIteratorBase::compute_shape(...)
2300 at::TensorIteratorBase::coalesce_dimensions()
1600 at::TensorIteratorBase::build(...)
-1100 at::TensorIteratorBase::numel() const
-6100 at::TensorIteratorBase::compute_fast_setup_type(...)
-22600 at::TensorIteratorBase::fast_set_up(...)
Total: 24000
这清楚地说明了正在发生的事情:TensorIterator
设置中有一个快速路径,但在 {128} x {1} 的情况下我们错过了它,不得不执行更通用的分析,这更加昂贵。过滤掉的最突出的调用是 c10::SmallVectorImpl<long>::operator=(…),它也是更通用设置的一部分。
7. 总结¶
总之,使用 Timer.blocked_autorange 收集实际时间(wall times)。如果计时变异性过高,请增加 min_run_time,或者如果方便的话,转移到 C++ 代码片段。
对于细粒度分析,使用 Timer.collect_callgrind 测量指令计数,并使用 FunctionCounts.(__add__ / __sub__ / transform / filter) 对其进行切片和整理。
8. 脚注¶
- 隐式
import torch
如果 globals 不包含“torch”,Timer 将自动填充它。这意味着
Timer("torch.empty(())")
将起作用。(尽管其他导入应该放在 setup 中,例如Timer("np.zeros(())", "import numpy as np")
)
REL_WITH_DEB_INFO
为了提供关于 PyTorch 内部执行的完整信息,
Callgrind
需要访问 C++ 调试符号。这通过在构建 PyTorch 时设置REL_WITH_DEB_INFO=1
来实现。否则,函数调用将是不透明的。(生成的CallgrindStats
会在缺少调试符号时发出警告。)
脚本总运行时间: ( 0 分 0.000 秒)