基准测试实用程序 - torch.utils.benchmark¶
- class torch.utils.benchmark.Timer(stmt='pass', setup='pass', global_setup='', timer=<built-in function perf_counter>, globals=None, label=None, sub_label=None, description=None, env=None, num_threads=1, language=Language.PYTHON)[source]¶
用于测量 PyTorch 语句执行时间的辅助类。
有关如何使用此类的完整教程,请参见:https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html
PyTorch Timer 基于 timeit.Timer(实际上在内部使用 timeit.Timer),但存在几个关键区别
- 运行时感知
Timer 将执行预热(很重要,因为 PyTorch 的某些元素是延迟初始化的),设置线程池大小,以便比较是苹果对苹果的,并在必要时同步异步 CUDA 函数。
- 关注副本
在测量代码,特别是复杂的内核/模型时,运行到运行的变化是一个重要的混淆因素。预期所有测量都应包括副本以量化噪声并允许中位数计算,这比平均值更稳健。为此,此类通过概念上合并 timeit.Timer.repeat 和 timeit.Timer.autorange 来偏离 timeit API。(确切的算法在方法文档字符串中讨论。)timeit 方法在不需要自适应策略的情况下被复制。
- 可选元数据
在定义 Timer 时,可以可选地指定 label、sub_label、description 和 env。(稍后定义)这些字段包含在结果对象的表示中,并且由 Compare 类用于对结果进行分组和显示以进行比较。
- 指令计数
除了挂钟时间之外,Timer 还可以在 Callgrind 下运行语句并报告执行的指令。
与 timeit.Timer 构造函数参数直接类似
stmt、setup、timer、globals
PyTorch Timer 特定的构造函数参数
label、sub_label、description、env、num_threads
- 参数
stmt (str) – 要在循环中运行并计时代码段。
setup (str) – 可选设置代码。用于定义在 stmt 中使用的变量
global_setup (str) – (仅限 C++) 代码,放置在文件的顶层,用于 #include 语句等。
timer (Callable[[], float]) – 返回当前时间的可调用对象。如果 PyTorch 在没有 CUDA 的情况下构建,或者没有 GPU,则默认为 timeit.default_timer;否则,它将在测量时间之前同步 CUDA。
globals (Optional[Dict[str, Any]]) – 一个 dict,它在执行 stmt 时定义全局变量。这是提供 stmt 需要变量的另一种方法。
label (Optional[str]) – 总结 stmt 的字符串。例如,如果 stmt 是 “torch.nn.functional.relu(torch.add(x, 1, out=out))”,则可以将标签设置为 “ReLU(x + 1)” 以提高可读性。
提供补充信息以区分具有相同 stmt 或标签的测量结果。例如,在我们上面的示例中,sub_label 可能是 “float” 或 “int”,因此可以轻松区分:“ReLU(x + 1):(float)”
”ReLU(x + 1):(int)” 在打印测量结果或使用 Compare 进行总结时。
用于区分具有相同标签和子标签的测量值的字符串。 description 的主要用途是向 Compare 信号数据列。 例如,可以根据输入大小设置它,以创建以下形式的表
| n=1 | n=4 | ... ------------- ... ReLU(x + 1): (float) | ... | ... | ... ReLU(x + 1): (int) | ... | ... | ...
使用 Compare。 在打印测量值时也会包含它。
env (可选[str]) – 此标签表示,否则相同的任务在不同的环境中运行,因此不具有等效性,例如,在对内核进行 A/B 测试时。 Compare 将在合并重复运行时,将具有不同 env 规范的测量值视为不同的。
num_threads (int) – 执行 stmt 时 PyTorch 线程池的大小。 单线程性能非常重要,因为它既是关键的推断工作负载,也是算法效率内在指标的良好指标,因此默认设置为 1。 这与默认的 PyTorch 线程池大小形成对比,后者试图利用所有核心。
- adaptive_autorange(threshold=0.1, *, min_run_time=0.01, max_run_time=10.0, callback=None)[source]¶
类似于 blocked_autorange,但也检查测量值的变异性,并重复执行,直到 iqr/median 小于 threshold 或达到 max_run_time。
从高级别来看,adaptive_autorange 执行以下伪代码
`setup` times = [] while times.sum < max_run_time start = timer() for _ in range(block_size): `stmt` times.append(timer() - start) enough_data = len(times)>3 and times.sum > min_run_time small_iqr=times.iqr/times.mean<threshold if enough_data and small_iqr: break
- 参数
- 返回
一个 Measurement 对象,其中包含测量的运行时间和重复次数,可用于计算统计数据。(平均值、中位数等)
- 返回类型
- blocked_autorange(callback=None, min_run_time=0.2)[source]¶
测量许多重复,同时将计时器开销降至最低。
从高级别来看,blocked_autorange 执行以下伪代码
`setup` total_time = 0 while total_time < min_run_time start = timer() for _ in range(block_size): `stmt` total_time += (timer() - start)
请注意内部循环中的变量 block_size。 块大小的选择对测量质量很重要,必须平衡两个相互竞争的目标
较小的块大小会导致更多重复,通常会获得更好的统计数据。
较大的块大小更好地摊销了 timer 调用的成本,并导致偏差较小的测量。 这很重要,因为 CUDA 同步时间是非平凡的(顺序从个位数到两位数的微秒),否则会使测量产生偏差。
blocked_autorange 通过运行预热阶段来设置 block_size,增加 block_size 直到计时器开销小于总体计算的 0.1%。 此值将用于主测量循环。
- 返回
一个 Measurement 对象,其中包含测量的运行时间和重复次数,可用于计算统计数据。(平均值、中位数等)
- 返回类型
- collect_callgrind(number: int, *, repeats: None, collect_baseline: bool, retain_out_file: bool) CallgrindStats [source]¶
- collect_callgrind(number: int, *, repeats: int, collect_baseline: bool, retain_out_file: bool) Tuple[CallgrindStats, ...]
使用 Callgrind 收集指令计数。
与墙上时间不同,指令计数是确定性的(除了程序本身的非确定性和 Python 解释器带来的少量抖动)。 这使得它们成为详细性能分析的理想选择。 此方法在单独的进程中运行 stmt,以便 Valgrind 可以对程序进行检测。 性能会因为检测而严重下降,但是,可以通过以下事实来缓解这种情况:通常只需要少量迭代即可获得良好的测量。
为了使用此方法,必须安装 valgrind、callgrind_control 和 callgrind_annotate。
因为在调用方(此进程)和 stmt 执行之间存在进程边界,所以 globals 不能包含任意内存中数据结构。(与计时方法不同)相反,globals 限于内置函数、nn.Modules 和 TorchScripted 函数/模块,以减少来自序列化和后续反序列化的意外因素。 GlobalsBridge 类对此主题提供了更多详细信息。 请特别注意 nn.Modules:它们依赖于 pickle,您可能需要在 setup 中添加一个导入才能让它们正确传输。
默认情况下,将为一个空语句收集和缓存一个配置文件,以指示有多少指令来自驱动 stmt 的 Python 循环。
- 返回
一个 CallgrindStats 对象,它提供指令计数和一些用于分析和操作结果的基本工具。
- timeit(number=1000000)[source]¶
镜像 timeit.Timer.timeit() 的语义。
执行主语句 (stmt) number 次。 https://docs.pythonlang.cn/3/library/timeit.html#timeit.Timer.timeit
- 返回类型
- class torch.utils.benchmark.Measurement(number_per_run, raw_times, task_spec, metadata=None)[source]¶
计时器测量的结果。
此类存储给定语句的一个或多个测量值。 它可以序列化,并为下游使用者提供几种便利方法(包括详细的 __repr__)。
- class torch.utils.benchmark.CallgrindStats(task_spec, number_per_run, built_with_debug_symbols, baseline_inclusive_stats, baseline_exclusive_stats, stmt_inclusive_stats, stmt_exclusive_stats, stmt_callgrind_out)[source]¶
Callgrind 由计时器收集的结果的顶级容器。
操作通常使用 FunctionCounts 类进行,该类是通过调用 CallgrindStats.stats(…) 获得的。还提供了一些便利方法;其中最重要的是 CallgrindStats.as_standardized()。
- as_standardized()[source]¶
从函数字符串中剥离库名称和一些前缀。
比较两组不同的指令计数时,一个绊脚石可能是路径前缀。Callgrind 在报告函数时包含完整的路径(应如此)。但是,这在比较配置文件时会导致问题。如果 Python 或 PyTorch 等关键组件在两个配置文件中构建在不同的位置,这会导致类似以下内容:
23234231 /tmp/first_build_dir/thing.c:foo(...) 9823794 /tmp/first_build_dir/thing.c:bar(...) ... 53453 .../aten/src/Aten/...:function_that_actually_changed(...) ... -9823794 /tmp/second_build_dir/thing.c:bar(...) -23234231 /tmp/second_build_dir/thing.c:foo(...)
剥离前缀可以通过规范化字符串和在比较时导致等效调用点的更好取消来改善此问题。
- 返回类型
- delta(other, inclusive=False)[source]¶
比较两组计数。
收集指令计数的一个常见原因是确定特定更改对执行某些工作单位所需的指令数量的影响。如果更改增加了该数量,下一个合乎逻辑的问题是“为什么”。这通常涉及查看代码的哪个部分增加了指令计数。此函数使该过程自动化,因此可以轻松地在包含和排他基础上比较计数。
- 返回类型
- class torch.utils.benchmark.FunctionCounts(_data, inclusive, truncate_rows=True, _linewidth=None)[source]¶
用于操作 Callgrind 结果的容器。
- 它支持
加法和减法以组合或比较结果。
类似元组的索引。
一个 denoise 函数,它剥离已知为非确定性和非常嘈杂的 CPython 调用。
两种高阶方法(filter 和 transform)用于自定义操作。
- denoise()[source]¶
删除已知的嘈杂指令。
CPython 解释器中的几条指令相当嘈杂。这些指令涉及 Python 用于映射变量名的 Unicode 到字典查找。FunctionCounts 通常是一个与内容无关的容器,但对于获得可靠结果而言,这足够重要,因此需要一个例外。
- 返回类型
- class torch.utils.benchmark.Compare(results)[source]¶
辅助类,用于以格式化的表格显示许多测量的结果。
表格格式基于
torch.utils.benchmark.Timer
(description, label, sub_label, num_threads 等) 中提供的字段信息。可以使用
print()
直接打印表格,也可以将其强制转换为 str。有关如何使用此类的完整教程,请参见:https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html
- 参数
results (List[Measurement]) – 要显示的 Measurment 列表。