保护概述¶
从 UX 角度来看,TorchDynamo 非常易于使用。用户调用 torchdynamo.optimize
作为注释
@torchdynamo.optimize(my_compiler)
def fn_foo(bar):
其中一个完整示例如下所示
from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
这允许 TorchDynamo 捕获解释的 Python 帧,获取任何和所有相关信息,并在任何可能的情况下加快速度。加速来自几个方面,并且可能相当依赖于提供的后端(上面的示例中的 my_compiler),但本节中重要的一个加速是缓存。缓存本身不是直接加速,但它是一种关键的启用功能,可以防止重新编译。我们用 dynamo 挖了一个洞,而缓存使我们能够摆脱困境。它使我们能够保持性能中立,同时启用后端 - 我们加速的真正来源。
即使提供了直通无操作后端
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
return gm.forward
我们也可以看到 TorchDynamo 即使在普通 Python 上也能加快 Python 执行速度,而不仅仅是 PyTorch。
缓存和保护概述¶
TorchDynamo 通过缓存由 TorchDynamo 转换的用户字节码来运行。当 TorchDynamo 收到一个帧进行评估时,它会检查帧中引用的对象是否以某些方式发生改变,如果没有,TorchDynamo 会读取先前转换的用户字节码来评估它。在本节中,我们将重点介绍如何识别帧中引用的对象是否发生改变。这是 TorchDynamo 中一项关键的功能,因为它驱动了整个失效生命周期。此功能称为保护。
在非常高的层面上,流程可以这样总结
TorchDynamo 接收 Python 帧。
它转换帧 (1),通过指令翻译传递它。
对于在 (2) 中捕获的对象,TorchDynamo 创建跟踪对象,这些对象
在输出图中进行跟踪,这是 torch.fx.Tracer 的内部专业化
保护
TorchDynamo 处理在 (3) 中创建的保护对象,将它们变成一个生成的 Python 函数 check_fn,该函数与一段代码相关联。
每当我们随后遇到此代码时,都会评估 check_fn - 如果 check_fn 通过并评估为 True,TorchDynamo 会将缓存中的代码和此处遇到的代码识别为相同,并且可以安全使用。如果它失败并评估为 False,TorchDynamo 会将缓存中的代码识别为无效,并且可以通过重新编译或图中断将其丢弃,以支持新条目。
Python 帧评估和 PEP 523¶
TorchDynamo 的功能基于 PEP 523。
TorchDynamo 通过使用 _PyInterpreterState_SetEvalFrameFunc 在 Python 上安装帧评估函数。TorchDynamo 有一个挂钩,在评估期间 Python 可以将控制权交还给我们。
我们安装的函数在 nopython=True
情况下是 convert_frame
或 convert_frame_assert
,但现在先忽略该细微差别,我们来看一下 convert_frame_assert
,因为 convert_frame
代理到它。
我们可以在 torch/_dynamo/convert_frame.py
中找到具有以下签名的函数
def convert_frame_assert(compiler_fn: Callable, one_graph=True):
此函数包装 Python 使用帧调用 TorchDynamo 的入口点
def _convert_frame_assert(frame: types.FrameType, cache_size: int):
此函数执行以下操作
检查它是否之前见过此
code
(请参阅:此处 的 f_code),如果见过则提前退出。检查代码是否是受不支持的情况。
检查
cache_size
(上面的第二个参数)是否超过配置中定义的限制cache_size_limit
。如果超过,该函数将删除帧并记录警告。这有助于避免帧的持续重新编译,因为它通常意味着帧以意外的方式变热,并且缓存它会产生不必要的开销,因为它在下一次遇到时很可能会被驱逐。通过
transform_code_object
传递帧以及一个通过字节码转换创建InstructionTranslator
的函数。这里有一些关键的事情发生通过
transform_code_object
产生新代码。通过
InstructionTranslator
生成名为output
的 FX 跟踪器。这可能会有点令人困惑,因为InstructionTranslator
不是 fx 跟踪器,但它存储在名为跟踪器的变量中,其输出是 fx 跟踪器。该函数生成保护并将其存储在上面的
output
上。该函数生成
output_instructions
并将其存储在上面的output
上。该函数将新生成的转换代码映射到从帧中读取的初始代码。值得记住此映射,我们将在后面介绍保护失败的内容中引用它。
该函数使用 4.1 中的转换代码和 4.3 中的保护生成 GuardedCode。
现在我们已经了解了帧评估,让我们回顾一下 InstructionTranslator
,看看它是如何将我们交给它的帧转换为 TorchDynamo 内部类型的。
InstructionTranslator¶
InstructionTranslator 做了很多!我们不会介绍它所做的一切的详细信息,但最重要的是,对于本文档,它会生成 symbolic_locals
的映射,该映射维护从帧的 f_locals
到 TorchDynamo 内部变量对象的映射(稍后会详细介绍)。symbolic_locals
通过遍历帧的本地变量来填充
self.symbolic_locals = collections.OrderedDict(
(k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
for k in vars
if k in f_locals
)
这里的重要组成部分是调用 VariableBuilder
中的调用。 VariableBuilder
的调用实现代理到一个名为 _wrap
的函数,该函数反过来既构造 VariableTracker
的实例,又对其调用 make_guards
。稍后会详细介绍。
反过来,此映射至关重要,因为每个变量都有关联的保护,然后这些保护被传递给 self.output
,即 OutputGraph
的实例,这是一个 fx 跟踪器,如上面第 4.2 节中所述。如果你还记得,这个存储在名为 output
的变量中的 OutputGraph
是我们在传递给 GuardedCode
之前存储保护的地方
InstructionTranslator
如何做到这一点?其核心是一个被泵送的循环,它驱动了一个函数 step
。
step
仅是单个处理步骤,它获取一条指令并对其执行某些操作。
注意
这些是 TorchDynamo 的 transform_code_object
处理的真实指令,非常酷。
注意
本节故意跳过了 dis.get_instructions 的详细信息。
对于上面的示例,以下是一些 Instruction
的片段:
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)
这是此函数的核心功能。看看 opname
,然后看看 step
中的这个小片段;
if not hasattr(self, inst.opname):
unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)
如我们所见,该函数检查当前类,即 InstructionTranslator
是否具有与操作符名称(例如,LOAD_CONST
)匹配的属性集。如果存在,则该函数调用它,并传入整个指令对象。如果不存在,则该函数将帧删除为未实现。
对于 LOAD_CONST
示例,我们可以看到我们确实支持它,并且定义相对简单
def LOAD_CONST(self, inst):
self.push(ConstantVariable(value=inst.argval))
我们可以看到,此函数创建了类 ConstantVariable
的新实例,其中包含一个值,在我们的示例中为 -1,然后将其推入堆栈。
有几十种这样的方法 - 请参阅 symbolic_convert.py
以了解所有这些方法。通常,我们会尽可能多地实现与 Python 字节码指令匹配的方法。
在 step
下游的逻辑和调用 VariableBuilder
的逻辑中 - 我们现在有很多 VariableTracker
,当然,我们已经讨论了很多关于创建保护措施的内容。让我们深入了解变量是什么,并更接近理解保护措施。
变量¶
ConstantVariable
是 VariableTracker
的一个实例。VariableTracker
表示跟踪的 Python 局部或堆栈值。
在 TorchDynamo 中表示对象时,VariableTracker
的作用正如其名,它跟踪给定的变量。它是一个非常灵活的类,但有几点需要注意
它通过以下方式管理底层对象周围的
guard
关系make_guard
replace_guards
add_guard(s)
propagate
-propagate(*vars: List[List["VariableTracker"]])
- 可能是最重要的,因为它组合了传入的所有VariableTracker
实例的 guard。它访问 guard 并将这些 guard 组合到自身上。
它作为底层对象的代理,为 TorchDynamo 的其余部分实现方法,以获取有关被跟踪对象的的信息
call_method
call_function
python_type
as_proxy
is/as_python_proxy
它存储类型为
Source
的变量source
,来自torchdynamo/source.py
。此源类型是一个相对独立的类,它帮助我们组织和记录原始源的来源,并帮助提供方便的方法来获取名称等内容,对我们来说重要的是生成 guard。
此类 (VariableTracker
) 围绕子类化构建,介于完整的抽象基类和完全充实的类之间 - 它会留下许多引发 NotImplementedError
的方法 - 依赖于子类。请参阅 torchdynamo/variables/
以了解所有子类以履行合约和自定义行为。
了解了这些知识后,我们可以看到 dis
中的指令 BUILD_TUPLE
的示例
BUILD_TUPLE(count)
创建一个元组,消耗堆栈中的 count 个项目,并将结果元组推送到堆栈上。
在我们的案例中,我们的签名会略有不同,这是因为我们创建 Instruction
对象的方式,但其要点是一样的。我们不传入 count
,而是传入一个带有少量额外簿记的对象,当然,我们处理将常规旧 Python 对象转换为 TorchDynamo 概念
def BUILD_TUPLE(self, inst):
items = self.popn(inst.argval)
options = VariableTracker.propagate(items)
self.push(TupleVariable(items, **options))
以下为此代码的作用
该函数读取
argval
,在这种情况下,类似于等效指令的 pydoc 中的counts
。函数
popn
弹出项目,在这种情况下,签名为def popn(self, n: int) -> List[TensorVariable]:
这暗示了一个底层契约 - 我们返回TensorVariables
。如果我们仔细查看symbolic_convert.py
和InstructionTranslatorBase
/InstructionTranslator
,我们会看到推送到我们的堆栈并从堆栈中弹出的唯一内容是VariableTracker
。
该函数调用
VariableTracker.propagate
。这会获取步骤 2 中从堆栈中弹出的每个项目的保护,并递归遍历它,将所有保护合并到options
中:py return { "guards": guards, }
然后,该函数使用
items
和options
创建一个新的VariableTracker
、TupleVariable
实例。然后,这允许我们安装来自组成新TupleVariable
的items
的所有适当保护。
注意
第一个保护来自哪里?传播是一个很好的技术,但我们需要在传播之前创建一些东西。当 VariableBuilder
从 f_locals
创建 VariableTracker
实例时,它会调用 make_guards
。这反过来又调用 source
,让它创建保护。
完成所有这些后,字节码转换完成,我们离生成 GuardedCode
又近了一步。我们现在了解了局部变量如何变成 VariableTracker
,如何处理指令,以及在何处调用保护程序以进行创建。在深入了解如何将代码和保护程序组合到 GuardedCode 对象中之前,我们需要深入研究一下上面的 make_guard
和 source.make_guard
调用。然后,我们可以了解在 VariableTracker
实例旁边和上面创建保护程序时发生了什么。
创建保护程序¶
保护程序只是类 Guard
的 Python 对象。让我们更详细地了解它们。
查看数据类的定义(因此,ctor 签名),我们看到它有一个名称、一个源和一个创建函数。
@dataclasses.dataclass
class Guard:
name: str
source: GuardSource
create_fn: Callable
名称应该是变量的名称。
此处的源是一个枚举,指示保护程序属于哪种类型的源。
注意
不要与 Source
和 source.py
中的其他类型混淆,因为它们存储在 VariableTracker
上。
create_fn
提供了从简单数据类过渡到实际生成有效的 Python 代码的主要功能,用于调用以了解在调用之间是否发生了变化,以及我们是否可以安全地从代码缓存中读取。
获取保护程序实例的最常见代码路径是通过 VariableTracker
上的 make_guards
。 make_guards
-> source.make_guard
-> return Guard(self.name(), self.guard_source(), fn)
或者,在具体示例中
...
elif istype(value, range):
guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
return RangeVariable(value=value, guards=guards)
由于 source
是在此 VariableTracker
的构造时间设置的,因此这里只需要提供 fn
,GuardBuilder.EQUALS_MATCH
到 create_fn
字段。
此 create_fn
必须是 GuardBuilder
上的方法。在我们的下一步中,原因变得显而易见。一旦为帧创建了所有保护程序,我们就继续使用 CheckFunctionManager
和 compile_check_fn
。
在 convert_frame
函数生成 GuardedCode
之前,它需要使用 CheckFunctionManager
运行所有防护措施,以生成 check_fn
,然后将该函数与代码一起传递给 GuardedCode
。这是我们存储在缓存条目中的 check_fn
,也是我们用来判断是否检索存储在旁边的代码的函数。作为参考,以下是该代码
static CacheEntry *create_cache_entry(CacheEntry *next,
PyObject *guarded_code) {
CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
DEBUG_NULL_CHECK(e);
e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
NULL_CHECK(e->check_fn);
e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
NULL_CHECK(e->code);
e->next = next;
return e;
}
我们现在知道如何使用 check_fn
函数,以及谁创建了它以及它的组成部分,但我们还不知道如何创建它。如何将 Guard
对象列表变成我们稍后可以运行的函数?
首先,我们迭代这些防护措施
for guard in sorted(guards or [], key=Guard.sort_key):
if not config.guard_nn_modules and guard.is_nn_module():
continue
guard.create(local_builder, global_builder)
调用 guard.create
将运行我们在 Guard
类上设置的 create_fn
(不要将其与我们正在努力生成的 check_fn
混淆,它们的名称相似,因此可能会有些混乱)。在上面的示例中,我们的 create_fn
是 GuardBuilder.EQUALS_MATCH
。因此,我们现在正在调用它,并传入 self
,即防护措施本身。
签名为:def EQUALS_MATCH(self, guard: Guard):
在该函数内部,我们可以使用防护措施上的 name
来获取我们的原始对象,查询它的数据和类型信息,进而获取最重要的部分:附加代码。
最简单的情况下,EQUALS_MATCH
只追加一行代码:self.code.append(f"{ref} == {val!r}")
。其中 ref
是变量的名称,val
是值。它可能会生成类似这样的代码
y == 2
这是一个基本示例。但如果我们追加一些其他类型的 GuardBuilder
函数,然后将它们全部与每个语句之间的 and
结合起来(正如我们所做的那样),我们可能会得到类似这样的结果
___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)
以下是此代码执行的操作
检查
.valid
类型 ID 检查
值检查
张量检查
这成为我们 check_fn
代码的核心,而我们再次遇到此代码时,又会评估 check_fn
。然后它将检查
此代码是否仍然有效?
如果 (1),
y
是否仍然具有94367738391392
的类型?如果 (2),
y
是否仍然为 2?如果 (3),让我们检查张量
x
是否以某些特定方式更改。
如果所有这些仍然为真,那么我们可以使用与此 check_fn
一起缓存的代码。
注意
有关此操作如何以及在何处发生的更深入的探讨,你可以阅读 static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) {
的 _eval_frame.c
。
如果不是,那么我们可以继续重新编译代码,并将其与此代码一起存储在缓存中,以及一个全新的 check_fn
,再次在另一个后续帧中进行检查。
在 GuardBuilder
上还有许多其他此类函数,它们有时会合并成巨大的字符串,然后将这些字符串作为 Python 代码进行评估并存储到 check_fn
中。上面的示例说明了一个简单的情况。为了更好地理解此功能,请阅读 GuardBuilder
上的其他函数,或者更好的是,转储 compile_check_fn
中的 code
变量,以查看正在生成的内容,尤其是在更大、真实的模型上。
摘要¶
在本节中,我们回顾了
.valid
的作用以及围绕弱引用(以及可能很快成为 NN 模块无效)的无效化。C++ 侧保护函数(
___check_type_id
、___check_tensors
等)如何操作。保护失败时会发生什么。
如果我们生成了无效的保护代码会发生什么。
我们介绍了如何对封装在 TorchDynamo 上下文中的用户提供的代码进行追踪和内部跟踪,将其组织到 VariableTracker
的 Source
中,随后组织到 Guard
中,以及这些 Guard
如何在处理 Python 代码时指导缓存条目选择和失效。