• 文档 >
  • 基准测试工具 - torch.utils.benchmark
快捷方式

基准测试工具 - torch.utils.benchmark

class torch.utils.benchmark.Timer(stmt='pass', setup='pass', global_setup='', timer=<built-in function perf_counter>, globals=None, label=None, sub_label=None, description=None, env=None, num_threads=1, language=Language.PYTHON)[源代码][源代码]

用于测量 PyTorch 语句执行时间的辅助类。

有关如何使用此类的完整教程,请参阅:https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html

PyTorch Timer 基于 timeit.Timer(实际上在内部使用了 timeit.Timer),但有一些关键区别:

  1. 运行时感知

    Timer 将执行预热(这很重要,因为 PyTorch 的某些元素是延迟初始化的),设置线程池大小以确保比较具有可比性,并在必要时同步异步 CUDA 函数。

  2. 侧重于重复测量

    在测量代码,特别是复杂的内核/模型时,运行间的变异是一个显著的混淆因素。期望所有测量都应包含重复测量,以量化噪声并允许计算中位数,中位数比均值更稳健。为此,此类在概念上合并了 timeit.Timer.repeattimeit.Timer.autorange,从而偏离了 timeit API。(具体算法在方法文档字符串中讨论。)对于不希望采用自适应策略的情况,timeit 方法被复制过来。

  3. 可选元数据

    定义 Timer 时,可以选择指定 labelsub_labeldescriptionenv。(稍后定义)这些字段包含在结果对象的表示中,并由 Compare 类用于分组和显示结果进行比较。

  4. 指令计数

    除了 wall time 外,Timer 还可以在 Callgrind 下运行语句并报告执行的指令数。

直接类似于 timeit.Timer 构造函数参数

stmtsetuptimerglobals

PyTorch Timer 特有的构造函数参数

labelsub_labeldescriptionenvnum_threads

参数
  • stmt (str) – 要在循环中运行并计时的代码片段。

  • setup (str) – 可选的设置代码。用于定义 stmt 中使用的变量。

  • global_setup (str) – (仅限 C++)放置在文件的顶层用于 #include 语句等内容的代码。

  • timer (Callable[[], float]) – 一个返回当前时间的 Callable。如果 PyTorch 在没有 CUDA 的情况下构建或不存在 GPU,则默认使用 timeit.default_timer;否则,它将在测量时间之前同步 CUDA。

  • globals (Optional[dict[str, Any]]) – 一个字典,定义了执行 stmt 时使用的全局变量。这是提供 stmt 所需变量的另一种方法。

  • label (Optional[str]) – 概括 stmt 的字符串。例如,如果 stmt 是“torch.nn.functional.relu(torch.add(x, 1, out=out))”,可以将 label 设置为“ReLU(x + 1)”以提高可读性。

  • sub_label (Optional[str]) –

    提供补充信息,以区分 stmt 或 label 相同的测量。例如,在上面的例子中,sub_label 可以是“float”或“int”,以便于区分:“ReLU(x + 1): (float)”

    “ReLU(x + 1): (int)” 当打印 Measurements 或使用 Compare 总结时。

  • description (Optional[str]) –

    用于区分 label 和 sub_label 相同的测量的字符串。description 的主要用途是向 Compare 指示数据列。例如,可以根据输入大小来设置它,以创建如下形式的表格:

                            | n=1 | n=4 | ...
                            ------------- ...
    ReLU(x + 1): (float)    | ... | ... | ...
    ReLU(x + 1): (int)      | ... | ... | ...
    

    使用 Compare。它在打印 Measurement 时也会包含。

  • env (Optional[str]) – 此标记表示在不同环境中运行了本应相同的任务,因此它们不等价,例如在对内核进行 A/B 测试时。Compare 在合并重复运行时,会将具有不同 env 规范的 Measurement 视为不同。

  • num_threads (int) – 执行 stmt 时 PyTorch 线程池的大小。单线程性能很重要,因为它既是关键的推理工作负载,也是衡量内在算法效率的良好指标,因此默认设置为 1。这与尝试利用所有核心的默认 PyTorch 线程池大小形成对比。

adaptive_autorange(threshold=0.1, *, min_run_time=0.01, max_run_time=10.0, callback=None)[源代码][源代码]

类似于 blocked_autorange,但也会检查测量的变异性,并重复直到 iqr/median 小于 threshold 或达到 max_run_time

在高层面上,adaptive_autorange 执行以下伪代码:

`setup`

times = []
while times.sum < max_run_time
    start = timer()
    for _ in range(block_size):
        `stmt`
    times.append(timer() - start)

    enough_data = len(times)>3 and times.sum > min_run_time
    small_iqr=times.iqr/times.mean<threshold

    if enough_data and small_iqr:
        break
参数
  • threshold (float) – iqr/median 阈值停止值

  • min_run_time (float) – 在检查 threshold 之前所需的总运行时间

  • max_run_time (float) – 无论 threshold 如何,所有测量的总运行时间

返回

一个 Measurement 对象,包含测量的运行时间和重复次数,可用于计算统计数据(均值、中位数等)。

返回类型

Measurement

blocked_autorange(callback=None, min_run_time=0.2)[源代码][源代码]

测量多个重复,同时将计时器开销降至最低。

在高层面上,blocked_autorange 执行以下伪代码:

`setup`

total_time = 0
while total_time < min_run_time
    start = timer()
    for _ in range(block_size):
        `stmt`
    total_time += (timer() - start)

请注意内层循环中的变量 block_size。块大小的选择对于测量质量很重要,必须平衡两个相互竞争的目标:

  1. 较小的块大小会产生更多的重复测量,通常会获得更好的统计数据。

  2. 较大的块大小可以更好地分摊计时器调用的成本,从而获得偏差较小的测量结果。这很重要,因为 CUDA 同步时间不可忽略(大约为个位数到低两位数微秒),否则会使测量产生偏差。

blocked_autorange 通过运行预热期来设置 block_size,增加块大小直到计时器开销小于总计算量的 0.1%。然后将此值用于主测量循环。

返回

一个 Measurement 对象,包含测量的运行时间和重复次数,可用于计算统计数据(均值、中位数等)。

返回类型

Measurement

collect_callgrind(number: int, *, repeats: None, collect_baseline: bool, retain_out_file: bool) CallgrindStats[源代码][源代码]
collect_callgrind(number: int, *, repeats: int, collect_baseline: bool, retain_out_file: bool) tuple[torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats, ...]

使用 Callgrind 收集指令计数。

与 wall time 不同,指令计数是确定性的(程序的非确定性以及 Python 解释器带来的少量抖动除外)。这使得它们非常适合进行详细的性能分析。此方法在单独的进程中运行 stmt,以便 Valgrind 可以对程序进行插装。由于插装,性能会严重下降,但通过较少的迭代次数通常足以获得良好的测量结果,可以缓解此问题。

要使用此方法,必须安装 valgrindcallgrind_controlcallgrind_annotate

由于调用者(当前进程)和 stmt 执行之间存在进程边界,因此全局变量不能包含任意内存数据结构。(与计时方法不同)相反,全局变量仅限于内置类型、nn.Modules 和 TorchScript 化函数/模块,以减少序列化和后续反序列化带来的意外因素。GlobalsBridge 类提供了有关此主题的更多详细信息。请特别注意 nn.Modules:它们依赖于 pickle,并且您可能需要在 setup 中添加导入才能正确传输。

默认情况下,将收集并缓存一个空语句的性能分析结果,以指示驱动 stmt 的 Python 循环有多少指令。

返回

一个 CallgrindStats 对象,提供指令计数以及一些用于分析和操作结果的基本功能。

timeit(number=1000000)[源代码][源代码]

镜像 timeit.Timer.timeit() 的语义。

执行主语句(stmtnumber 次。https://docs.pythonlang.cn/3/library/timeit.html#timeit.Timer.timeit

返回类型

Measurement

class torch.utils.benchmark.Measurement(number_per_run, raw_times, task_spec, metadata=None)[源代码][源代码]

Timer 测量结果。

此类存储给定语句的一个或多个测量结果。它是可序列化的,并为下游使用者提供了多种便利方法(包括详细的 __repr__)。

static merge(measurements)[源代码][源代码]

合并重复测量的便利方法。

Merge 将时间外推到 number_per_run=1,并且不会传输任何元数据。(因为它在重复测量之间可能有所不同)

返回类型

list['Measurement']

property significant_figures: int

近似的有效数字估计。

此属性旨在提供一种方便的方式来估计测量的精度。它仅使用四分位数区域来估计统计数据,以尝试减轻尾部偏差,并使用静态 z 值 1.645,因为它不期望用于小的 n 值,因此 z 可以近似 t

有效数字估计与 trim_sigfig 方法结合使用,以提供更易于人工理解的数据摘要。__repr__ 不使用此方法;它只显示原始值。有效数字估计旨在用于 Compare

class torch.utils.benchmark.CallgrindStats(task_spec, number_per_run, built_with_debug_symbols, baseline_inclusive_stats, baseline_exclusive_stats, stmt_inclusive_stats, stmt_exclusive_stats, stmt_callgrind_out)[源代码][源代码]

Timer 收集的 Callgrind 结果的顶层容器。

操作通常使用 FunctionCounts 类完成,该类通过调用 CallgrindStats.stats(...) 获得。还提供了一些便利方法;最重要的是 CallgrindStats.as_standardized()

as_standardized()[源代码][源代码]

从函数字符串中去除库名称和某些前缀。

比较两组不同的指令计数时,一个障碍可能是路径前缀。Callgrind 在报告函数时会包含完整的文件路径(这是应该的)。然而,这在对比性能分析结果时可能会导致问题。如果在两个性能分析结果中,像 Python 或 PyTorch 这样的关键组件是在不同的位置构建的,可能会出现类似以下情况:

23234231 /tmp/first_build_dir/thing.c:foo(...)
 9823794 /tmp/first_build_dir/thing.c:bar(...)
  ...
   53453 .../aten/src/Aten/...:function_that_actually_changed(...)
  ...
 -9823794 /tmp/second_build_dir/thing.c:bar(...)
-23234231 /tmp/second_build_dir/thing.c:foo(...)

去除前缀可以通过规范化字符串并在对比时更好地抵消等效调用点来缓解此问题。

返回类型

CallgrindStats

counts(*, denoise=False)[源代码][源代码]

返回执行的指令总数。

有关 denoise 参数的说明,请参阅 FunctionCounts.denoise()

返回类型

int

delta(other, inclusive=False)[源代码][源代码]

对比两组计数。

收集指令计数的一个常见原因是确定特定更改对执行某个工作单元所需的指令数产生的影响。如果更改增加了该数字,下一个合乎逻辑的问题就是“为什么”。这通常涉及查看代码的哪个部分增加了指令计数。此函数自动化了这一过程,以便可以轻松地基于包含和排除方式对比计数。

返回类型

FunctionCounts

stats(inclusive=False)[源代码][源代码]

返回详细的函数计数。

从概念上讲,返回的 FunctionCounts 可以被视为 (计数, 路径和函数名) 元组的元组。

inclusive 与 callgrind 的语义相符。如果为 True,则计数包括子函数执行的指令。inclusive=True 对于识别代码中的热点很有用;inclusive=False 对于减少对比两次不同运行的计数时的噪声很有用。(有关更多详细信息,请参阅 CallgrindStats.delta(...))

返回类型

FunctionCounts

class torch.utils.benchmark.FunctionCounts(_data, inclusive, truncate_rows=True, _linewidth=None)[源代码][源代码]

用于操作 Callgrind 结果的容器。

它支持:
  1. 加法和减法以组合或对比结果。

  2. 类似元组的索引。

  3. 一个去噪函数,用于去除已知为非确定性且噪音很大的 CPython 调用。

  4. 两个用于自定义操作的高阶方法(filtertransform)。

denoise()[source][source]

移除已知的噪声指令。

CPython 解释器中的一些指令相当嘈杂。这些指令涉及 Python 用于映射变量名的 Unicode 到字典查找操作。FunctionCounts 通常是一个与内容无关的容器,但对于获得可靠结果来说,这一点非常重要,值得作为一个例外。

返回类型

FunctionCounts

filter(filter_fn)[source][source]

仅保留将 filter_fn 应用于函数名后返回 True 的元素。

返回类型

FunctionCounts

transform(map_fn)[source][source]

map_fn 应用于所有函数名。

这可用于规范化函数名(例如,剥离文件路径中不相关的部分),通过将多个函数映射到同一个名称来合并条目(在这种情况下,计数会累加),等等。

返回类型

FunctionCounts

class torch.utils.benchmark.Compare(results)[source][source]

用于在格式化表格中显示多个测量结果的辅助类。

表格格式基于 torch.utils.benchmark.Timer 中提供的信息字段(descriptionlabelsub_labelnum_threads 等)。

表格可以使用 print() 直接打印,或转换为 str 类型。

有关如何使用此类的完整教程,请参阅:https://pytorch.ac.cn/tutorials/recipes/recipes/benchmark.html

参数

results (list[torch.utils.benchmark.utils.common.Measurement]) – 要显示的测量结果列表。

colorize(rowwise=False)[source][source]

为格式化表格着色。

默认按列着色。

extend_results(results)[source][source]

将结果追加到已存储的结果中。

所有添加的结果必须是 Measurement 的实例。

highlight_warnings()[source][source]

构建格式化表格时启用警告高亮。

print()[source][source]

打印格式化表格

trim_significant_figures()[source][source]

构建格式化表格时启用有效数字修剪。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源