作者:Kaichao You

depyf logo

我们很高兴向 PyTorch 生态系统推出一个新项目:depyf!该项目旨在帮助用户理解、学习和适应 torch.compile

动机

torch.compile 是 PyTorch 2.x 的基石,只需一行代码即可为训练和推理加速机器学习工作流程。仅仅包含 @torch.compile 就能显著提升代码性能。然而,找到 torch.compile 的最佳插入点并不容易,更不用说调整各种参数以实现最高效率的复杂性了。

torch.compile 堆栈的复杂性,包括 Dynamo、AOTAutograd、Inductor 等,带来了陡峭的学习曲线。这些组件对于深度学习性能优化至关重要,但如果没有扎实的基础,可能会令人望而生畏。

注意:有关 torch.compile 工作原理的入门示例,请参阅此逐步讲解

一个常用工具:TORCH_COMPILE_DEBUG

为了揭秘 torch.compile,常用的方法是利用 TORCH_COMPILE_DEBUG 环境变量。虽然它提供了更多信息,但解读其输出仍然是一项艰巨的任务。

例如,当我们有以下代码时

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   main()

并使用 TORCH_COMPILE_DEBUG=1 python test.py 运行它时,我们会得到一个名为 torch_compile_debug/run_2024_02_05_23_02_45_552124-pid_9520 的目录,该目录下有这些文件

.
├── torchdynamo
│   └── debug.log
└── torchinductor
   ├── aot_model___0_debug.log
   ├── aot_model___10_debug.log
   ├── aot_model___11_debug.log
   ├── model__4_inference_10.1
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   ├── model__5_inference_11.2
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   └── model___9.0
       ├── fx_graph_readable.py
       ├── fx_graph_runnable.py
       ├── fx_graph_transformed.py
       ├── ir_post_fusion.txt
       ├── ir_pre_fusion.txt
       └── output_code.py

生成的文件和日志往往带来的问题比解答的还多,让开发者对数据中的含义和关系感到困惑。TORCH_COMPILE_DEBUG 常见的困惑包括

  • model__4_inference_10.1 是什么意思?
  • 我有一个函数,但在目录中有三个 model__xxx.py,它们之间有什么对应关系?
  • debug.log 中的那些 LOAD_GLOBAL 是什么?

更好的工具:depyf 助您脱困

让我们看看 depyf 如何帮助开发者解决上述挑战。要使用 depyf,只需执行 pip install depyf 或参照项目页面 https://github.com/thuml/depyf 安装最新版本,然后将主代码包裹在 with depyf.prepare_debug 中即可。

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()

执行 python test.py 后,depyf 将生成一个名为 depyf_debug_dir 的目录(这是 prepare_debug 函数的参数)。该目录下会有这些文件

.
├── __compiled_fn_0 AFTER POST GRAD 0.py
├── __compiled_fn_0 Captured Graph 0.py
├── __compiled_fn_0 Forward graph 0.py
├── __compiled_fn_0 kernel 0.py
├── __compiled_fn_3 AFTER POST GRAD 0.py
├── __compiled_fn_3 Captured Graph 0.py
├── __compiled_fn_3 Forward graph 0.py
├── __compiled_fn_3 kernel 0.py
├── __compiled_fn_4 AFTER POST GRAD 0.py
├── __compiled_fn_4 Captured Graph 0.py
├── __compiled_fn_4 Forward graph 0.py
├── __compiled_fn_4 kernel 0.py
├── __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_8.py
├── __transformed_code_0_for_toy_example.py
├── __transformed_code_1_for_torch_dynamo_resume_in_toy_example_at_8.py
└── full_code_for_toy_example_0.py

其中有两大显著优势

  1. 冗长难懂的 torchdynamo/debug.log 不见了。其内容经过清理,以人类可读的源代码形式展示在 full_code_for_xxx.py__transformed_code_{n}_for_xxx.py 中。值得注意的是,depyf 最繁琐困难的工作就是将 torchdynamo/debug.log 内部的字节码反编译成 Python 源代码,从而使开发者摆脱了 Python 令人望而生畏的内部细节。
  2. 函数名称和计算图之间的对应关系得到了保留。例如,在 __transformed_code_0_for_toy_example.py 中,我们可以看到一个名为 __compiled_fn_0 的函数,我们会立即知道其对应的计算图在 __compiled_fn_0_xxx.py 中,因为它们共享相同的 __compiled_fn_0 前缀名称。

full_code_for_xxx.py 开始,并跟随其中涉及的函数,用户将清楚地了解 torch.compile 对其代码做了什么。

还有一点:逐行调试能力

使用调试器逐行调试代码是理解代码工作原理的好方法。然而,在使用 TORCH_COMPILE_DEBUG 时,这些文件仅供用户参考信息,不能与用户关注的数据一起执行。

注意:“调试”在此处指检查和改进程序的过程,而非纠正有 Bug 的代码。

depyf 的一个突出特点是它能够促进对 torch.compile 进行逐行调试:它生成的所有文件都与 Python 解释器内部的运行时代码对象关联,我们可以在这些文件中设置断点。用法很简单,只需添加一个上下文管理器 with depyf.debug(),即可实现此功能

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()
   with depyf.debug():
       main()

只需要注意一点:调试 torch.compile 的工作流程与标准调试工作流程有所不同。使用 torch.compile 时,许多代码是动态生成的。因此,我们需要

  1. 启动程序
  2. 当程序退出 with depyf.prepare_debug("depyf_debug_dir") 后,代码将在 depyf_debug_dir 中可用。
  3. 当程序进入 with depyf.debug() 时,它会在内部自动设置一个断点,以便程序暂停。
  4. 导航到 depyf_debug_dir 设置断点。
  5. 继续运行代码,调试器将命中这些断点!

depyf screenshot

以下是截图示例。所有代码和 Tensor 变量都是“实时”的,我们可以检查任何变量,并逐行执行代码,就像我们日常的调试工作流程一样!唯一的区别是我们正在调试由 torch.compile 生成的代码,而不是人工编写的代码。

结论

torch.compile 是一个宝贵的工具,可以轻松加速 PyTorch 代码。对于希望深入了解 torch.compile 的用户,无论是为了充分发挥其潜力还是集成自定义操作,学习曲线可能会非常陡峭。depyf 旨在降低这一门槛,提供用户友好的体验,帮助理解、学习和适应 torch.compile

请务必探索 depyf 并亲身体验其优势!该项目是开源的,可在 https://github.com/thuml/depyf 轻松获取。通过 pip install depyf 安装非常简单。我们希望 depyf 能够提升每个人使用 torch.compile 的开发工作流程。