保护概述¶
从 UX 角度来看,TorchDynamo 非常易于使用。用户调用 torchdynamo.optimize 作为注释
@torchdynamo.optimize(my_compiler)
def fn_foo(bar):
其中一个完整示例如下所示
from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
这允许 TorchDynamo 捕获解释的 Python 帧,获取任何和所有相关信息,并在任何可能的情况下加快速度。加速来自几个方面,并且可能相当依赖于提供的后端(上面的示例中的 my_compiler),但本节中重要的一个加速是缓存。缓存本身不是直接加速,但它是一种关键的启用功能,可以防止重新编译。我们用 dynamo 挖了一个洞,而缓存使我们能够摆脱困境。它使我们能够保持性能中立,同时启用后端 - 我们加速的真正来源。
即使提供了直通无操作后端
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return gm.forward
我们也可以看到 TorchDynamo 即使在普通 Python 上也能加快 Python 执行速度,而不仅仅是 PyTorch。
缓存和保护概述¶
TorchDynamo 通过缓存由 TorchDynamo 转换的用户字节码来运行。当 TorchDynamo 收到一个帧进行评估时,它会检查帧中引用的对象是否以某些方式发生改变,如果没有,TorchDynamo 会读取先前转换的用户字节码来评估它。在本节中,我们将重点介绍如何识别帧中引用的对象是否发生改变。这是 TorchDynamo 中一项关键的功能,因为它驱动了整个失效生命周期。此功能称为保护。
在非常高的层面上,流程可以这样总结
- TorchDynamo 接收 Python 帧。 
- 它转换帧 (1),通过指令翻译传递它。 
- 对于在 (2) 中捕获的对象,TorchDynamo 创建跟踪对象,这些对象 - 在输出图中进行跟踪,这是 torch.fx.Tracer 的内部专业化 
- 保护 
 
- TorchDynamo 处理在 (3) 中创建的保护对象,将它们变成一个生成的 Python 函数 check_fn,该函数与一段代码相关联。 
- 每当我们随后遇到此代码时,都会评估 check_fn - 如果 check_fn 通过并评估为 True,TorchDynamo 会将缓存中的代码和此处遇到的代码识别为相同,并且可以安全使用。如果它失败并评估为 False,TorchDynamo 会将缓存中的代码识别为无效,并且可以通过重新编译或图中断将其丢弃,以支持新条目。 
Python 帧评估和 PEP 523¶
TorchDynamo 的功能基于 PEP 523。
TorchDynamo 通过使用 _PyInterpreterState_SetEvalFrameFunc 在 Python 上安装帧评估函数。TorchDynamo 有一个挂钩,在评估期间 Python 可以将控制权交还给我们。
我们安装的函数在 nopython=True 情况下是 convert_frame 或 convert_frame_assert,但现在先忽略该细微差别,我们来看一下 convert_frame_assert,因为 convert_frame 代理到它。
我们可以在 torch/_dynamo/convert_frame.py 中找到具有以下签名的函数
def  convert_frame_assert(compiler_fn: Callable, one_graph=True):
此函数包装 Python 使用帧调用 TorchDynamo 的入口点
def  _convert_frame_assert(frame: types.FrameType, cache_size: int):
此函数执行以下操作
- 检查它是否之前见过此 - code(请参阅:此处 的 f_code),如果见过则提前退出。
- 检查代码是否是受不支持的情况。 
- 检查 - cache_size(上面的第二个参数)是否超过配置中定义的限制- cache_size_limit。如果超过,该函数将删除帧并记录警告。这有助于避免帧的持续重新编译,因为它通常意味着帧以意外的方式变热,并且缓存它会产生不必要的开销,因为它在下一次遇到时很可能会被驱逐。
- 通过 - transform_code_object传递帧以及一个通过字节码转换创建- InstructionTranslator的函数。这里有一些关键的事情发生- 通过 - transform_code_object产生新代码。
- 通过 - InstructionTranslator生成名为- output的 FX 跟踪器。这可能会有点令人困惑,因为- InstructionTranslator不是 fx 跟踪器,但它存储在名为跟踪器的变量中,其输出是 fx 跟踪器。
- 该函数生成保护并将其存储在上面的 - output上。
- 该函数生成 - output_instructions并将其存储在上面的- output上。
- 该函数将新生成的转换代码映射到从帧中读取的初始代码。值得记住此映射,我们将在后面介绍保护失败的内容中引用它。 
 
- 该函数使用 4.1 中的转换代码和 4.3 中的保护生成 GuardedCode。 
现在我们已经了解了帧评估,让我们回顾一下 InstructionTranslator,看看它是如何将我们交给它的帧转换为 TorchDynamo 内部类型的。
InstructionTranslator¶
InstructionTranslator 做了很多!我们不会介绍它所做的一切的详细信息,但最重要的是,对于本文档,它会生成 symbolic_locals 的映射,该映射维护从帧的 f_locals 到 TorchDynamo 内部变量对象的映射(稍后会详细介绍)。symbolic_locals 通过遍历帧的本地变量来填充
self.symbolic_locals = collections.OrderedDict(
    (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
    for k in vars
    if k in f_locals
)
这里的重要组成部分是调用 VariableBuilder 中的调用。 VariableBuilder 的调用实现代理到一个名为 _wrap 的函数,该函数反过来既构造 VariableTracker 的实例,又对其调用 make_guards。稍后会详细介绍。
反过来,此映射至关重要,因为每个变量都有关联的保护,然后这些保护被传递给 self.output,即 OutputGraph 的实例,这是一个 fx 跟踪器,如上面第 4.2 节中所述。如果你还记得,这个存储在名为 output 的变量中的 OutputGraph 是我们在传递给 GuardedCode 之前存储保护的地方
InstructionTranslator 如何做到这一点?其核心是一个被泵送的循环,它驱动了一个函数 step。
step 仅是单个处理步骤,它获取一条指令并对其执行某些操作。
注意
这些是 TorchDynamo 的 transform_code_object 处理的真实指令,非常酷。
注意
本节故意跳过了 dis.get_instructions 的详细信息。
对于上面的示例,以下是一些 Instruction 的片段:
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)
这是此函数的核心功能。看看 opname,然后看看 step 中的这个小片段;
if not hasattr(self, inst.opname):
    unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)
如我们所见,该函数检查当前类,即 InstructionTranslator 是否具有与操作符名称(例如,LOAD_CONST)匹配的属性集。如果存在,则该函数调用它,并传入整个指令对象。如果不存在,则该函数将帧删除为未实现。
对于 LOAD_CONST 示例,我们可以看到我们确实支持它,并且定义相对简单
def LOAD_CONST(self, inst):
    self.push(ConstantVariable(value=inst.argval))
我们可以看到,此函数创建了类 ConstantVariable 的新实例,其中包含一个值,在我们的示例中为 -1,然后将其推入堆栈。
有几十种这样的方法 - 请参阅 symbolic_convert.py 以了解所有这些方法。通常,我们会尽可能多地实现与 Python 字节码指令匹配的方法。
在 step 下游的逻辑和调用 VariableBuilder 的逻辑中 - 我们现在有很多 VariableTracker,当然,我们已经讨论了很多关于创建保护措施的内容。让我们深入了解变量是什么,并更接近理解保护措施。
变量¶
ConstantVariable 是 VariableTracker 的一个实例。VariableTracker 表示跟踪的 Python 局部或堆栈值。
在 TorchDynamo 中表示对象时,VariableTracker 的作用正如其名,它跟踪给定的变量。它是一个非常灵活的类,但有几点需要注意
- 它通过以下方式管理底层对象周围的 - guard关系- make_guard
- replace_guards
- add_guard(s)
- propagate-- propagate(*vars: List[List["VariableTracker"]])- 可能是最重要的,因为它组合了传入的所有- VariableTracker实例的 guard。它访问 guard 并将这些 guard 组合到自身上。
 
- 它作为底层对象的代理,为 TorchDynamo 的其余部分实现方法,以获取有关被跟踪对象的的信息 - call_method
- call_function
- python_type
- as_proxy
- is/as_python_proxy
 
- 它存储类型为 - Source的变量- source,来自- torchdynamo/source.py。此源类型是一个相对独立的类,它帮助我们组织和记录原始源的来源,并帮助提供方便的方法来获取名称等内容,对我们来说重要的是生成 guard。
此类 (VariableTracker) 围绕子类化构建,介于完整的抽象基类和完全充实的类之间 - 它会留下许多引发 NotImplementedError 的方法 - 依赖于子类。请参阅 torchdynamo/variables/ 以了解所有子类以履行合约和自定义行为。
了解了这些知识后,我们可以看到 dis 中的指令 BUILD_TUPLE 的示例
BUILD_TUPLE(count)创建一个元组,消耗堆栈中的 count 个项目,并将结果元组推送到堆栈上。
在我们的案例中,我们的签名会略有不同,这是因为我们创建 Instruction 对象的方式,但其要点是一样的。我们不传入 count,而是传入一个带有少量额外簿记的对象,当然,我们处理将常规旧 Python 对象转换为 TorchDynamo 概念
def BUILD_TUPLE(self, inst):
    items = self.popn(inst.argval)
    options = VariableTracker.propagate(items)
    self.push(TupleVariable(items, **options))
以下为此代码的作用
- 该函数读取 - argval,在这种情况下,类似于等效指令的 pydoc 中的- counts。
- 函数 - popn弹出项目,在这种情况下,签名为- def popn(self, n: int) -> List[TensorVariable]:这暗示了一个底层契约 - 我们返回- TensorVariables。如果我们仔细查看- symbolic_convert.py和- InstructionTranslatorBase/- InstructionTranslator,我们会看到推送到我们的堆栈并从堆栈中弹出的唯一内容是- VariableTracker。
- 该函数调用 - VariableTracker.propagate。这会获取步骤 2 中从堆栈中弹出的每个项目的保护,并递归遍历它,将所有保护合并到- options中:- py return { "guards": guards, }
- 然后,该函数使用 - items和- options创建一个新的- VariableTracker、- TupleVariable实例。然后,这允许我们安装来自组成新- TupleVariable的- items的所有适当保护。
注意
第一个保护来自哪里?传播是一个很好的技术,但我们需要在传播之前创建一些东西。当 VariableBuilder 从 f_locals 创建 VariableTracker 实例时,它会调用 make_guards。这反过来又调用 source,让它创建保护。
完成所有这些后,字节码转换完成,我们离生成 GuardedCode 又近了一步。我们现在了解了局部变量如何变成 VariableTracker,如何处理指令,以及在何处调用保护程序以进行创建。在深入了解如何将代码和保护程序组合到 GuardedCode 对象中之前,我们需要深入研究一下上面的 make_guard 和 source.make_guard 调用。然后,我们可以了解在 VariableTracker 实例旁边和上面创建保护程序时发生了什么。
创建保护程序¶
保护程序只是类 Guard 的 Python 对象。让我们更详细地了解它们。
查看数据类的定义(因此,ctor 签名),我们看到它有一个名称、一个源和一个创建函数。
@dataclasses.dataclass
class Guard:
    name: str
    source: GuardSource
    create_fn: Callable
名称应该是变量的名称。
此处的源是一个枚举,指示保护程序属于哪种类型的源。
注意
不要与 Source 和 source.py 中的其他类型混淆,因为它们存储在 VariableTracker 上。
create_fn 提供了从简单数据类过渡到实际生成有效的 Python 代码的主要功能,用于调用以了解在调用之间是否发生了变化,以及我们是否可以安全地从代码缓存中读取。
获取保护程序实例的最常见代码路径是通过 VariableTracker 上的 make_guards。 make_guards -> source.make_guard -> return Guard(self.name(), self.guard_source(), fn)
或者,在具体示例中
...
elif istype(value, range):
    guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
    return RangeVariable(value=value, guards=guards)
由于 source 是在此 VariableTracker 的构造时间设置的,因此这里只需要提供 fn,GuardBuilder.EQUALS_MATCH 到 create_fn 字段。
此 create_fn 必须是 GuardBuilder 上的方法。在我们的下一步中,原因变得显而易见。一旦为帧创建了所有保护程序,我们就继续使用 CheckFunctionManager 和 compile_check_fn。
在 convert_frame 函数生成 GuardedCode 之前,它需要使用 CheckFunctionManager 运行所有防护措施,以生成 check_fn,然后将该函数与代码一起传递给 GuardedCode。这是我们存储在缓存条目中的 check_fn,也是我们用来判断是否检索存储在旁边的代码的函数。作为参考,以下是该代码
static CacheEntry *create_cache_entry(CacheEntry *next,
                                      PyObject *guarded_code) {
  CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
  DEBUG_NULL_CHECK(e);
  e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
  NULL_CHECK(e->check_fn);
  e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
  NULL_CHECK(e->code);
  e->next = next;
  return e;
}
我们现在知道如何使用 check_fn 函数,以及谁创建了它以及它的组成部分,但我们还不知道如何创建它。如何将 Guard 对象列表变成我们稍后可以运行的函数?
首先,我们迭代这些防护措施
for guard in sorted(guards or [], key=Guard.sort_key):
    if not config.guard_nn_modules and guard.is_nn_module():
        continue
    guard.create(local_builder, global_builder)
调用 guard.create 将运行我们在 Guard 类上设置的 create_fn(不要将其与我们正在努力生成的 check_fn 混淆,它们的名称相似,因此可能会有些混乱)。在上面的示例中,我们的 create_fn 是 GuardBuilder.EQUALS_MATCH。因此,我们现在正在调用它,并传入 self,即防护措施本身。
签名为:def EQUALS_MATCH(self, guard: Guard):
在该函数内部,我们可以使用防护措施上的 name 来获取我们的原始对象,查询它的数据和类型信息,进而获取最重要的部分:附加代码。
最简单的情况下,EQUALS_MATCH 只追加一行代码:self.code.append(f"{ref} == {val!r}")。其中 ref 是变量的名称,val 是值。它可能会生成类似这样的代码
y == 2
这是一个基本示例。但如果我们追加一些其他类型的 GuardBuilder 函数,然后将它们全部与每个语句之间的 and 结合起来(正如我们所做的那样),我们可能会得到类似这样的结果
___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)
以下是此代码执行的操作
- 检查 - .valid
- 类型 ID 检查 
- 值检查 
- 张量检查 
这成为我们 check_fn 代码的核心,而我们再次遇到此代码时,又会评估 check_fn。然后它将检查
- 此代码是否仍然有效? 
- 如果 (1), - y是否仍然具有- 94367738391392的类型?
- 如果 (2), - y是否仍然为 2?
- 如果 (3),让我们检查张量 - x是否以某些特定方式更改。
如果所有这些仍然为真,那么我们可以使用与此 check_fn 一起缓存的代码。
注意
有关此操作如何以及在何处发生的更深入的探讨,你可以阅读 static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) { 的 _eval_frame.c。
如果不是,那么我们可以继续重新编译代码,并将其与此代码一起存储在缓存中,以及一个全新的 check_fn,再次在另一个后续帧中进行检查。
在 GuardBuilder 上还有许多其他此类函数,它们有时会合并成巨大的字符串,然后将这些字符串作为 Python 代码进行评估并存储到 check_fn 中。上面的示例说明了一个简单的情况。为了更好地理解此功能,请阅读 GuardBuilder 上的其他函数,或者更好的是,转储 compile_check_fn 中的 code 变量,以查看正在生成的内容,尤其是在更大、真实的模型上。
摘要¶
在本节中,我们回顾了
- .valid的作用以及围绕弱引用(以及可能很快成为 NN 模块无效)的无效化。
- C++ 侧保护函数( - ___check_type_id、- ___check_tensors等)如何操作。
- 保护失败时会发生什么。 
- 如果我们生成了无效的保护代码会发生什么。 
我们介绍了如何对封装在 TorchDynamo 上下文中的用户提供的代码进行追踪和内部跟踪,将其组织到 VariableTracker 的 Source 中,随后组织到 Guard 中,以及这些 Guard 如何在处理 Python 代码时指导缓存条目选择和失效。