Dynamo 深度探索¶
TorchDynamo(或简称为 Dynamo)是 torch.compile
中的跟踪器,它通常是那些疯狂回溯的罪魁祸首。然而,我们不能盲目地将这些错误归咎于 Dynamo。为了给用户提供所需的灵活性,Dynamo 承担了理解任何 Python 程序的艰巨任务。特别是,Dynamo 必须在内部实现 Python 编程语言的很大一部分!
在这篇文章中,我们将从头开始介绍 Dynamo 的内部设计。我们将讨论它提供的功能以及如何实现。读完这篇文章后,你将更好地理解当你使用 torch.compile
编译 PyTorch 程序时出现错误,或者编译成功但加速效果不如预期时,究竟是哪里出了问题。
Dynamo 温和入门¶
在深入了解所有实现细节之前,我们先来讨论 Dynamo 的作用。
Dynamo 是一个跟踪器。这意味着,给定一个函数及其输入,它会执行该函数并将一系列线性指令(无控制流)记录到一个图中。例如,考虑以下程序
import torch
@torch.compile
def mse(x, y):
z = (x - y) ** 2
return z.sum()
x = torch.randn(200)
y = torch.randn(200)
mse(x, y)
如果我们将此程序保存到文件 example.py
中并运行
TORCH_LOGS=graph_code python example.py
我们会看到 Dynamo 跟踪的输出
def forward(l_x_: torch.Tensor, l_y_: torch.Tensor):
# File: example.py:5, code: z = (x - y) ** 2
sub = l_x_ - l_y_
z = sub ** 2
# File: example.py:6, code: return z.sum()
sum_1 = z.sum()
return (sum_1,)
我们将这称为给定输入的函数图(或跟踪)。它通过 FX 图表示。我们可以简单地将 FX 图视为一个存储函数调用列表的容器。
我们首先应该注意到的是,图是 PyTorch 操作的线性序列。1 Dynamo 记录所有 PyTorch 操作并按顺序存储。例如,它将 z = (x - y) ** 2
拆分为两个组成操作:sub = l_x_ - l_y_
和 z = sub ** 2
。
当说跟踪是线性的时,意味着没有分支或任何控制流。为了验证这一点,考虑
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
当使用 TORCH_LOGS=graph_code
执行时,会返回
def forward(l_x_: torch.Tensor):
# File: example.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: example.py:7, code: return (n + 1) * y
mul = 3 * y
return (mul,)
我们看到 Dynamo 完全从跟踪中移除了 if
语句,只记录了使用输入执行的操作。
因此,应该清楚的是,函数的跟踪取决于输入。特别是,这意味着跟踪不是在我们编写 @torch.compile
时生成的,而是在我们使用实际参数执行函数 fn(x, 2)
时生成的。
这里另一个值得注意的有趣之处是,Dynamo 移除了函数的第二个参数。相反,它将其视为常量并在图中记录了操作 n + 1
的结果。这是 Dynamo 的另一个特性:Dynamo 会将除整数以外的任何非张量值视为常量。现在来看看整数为何特别。
Dynamo 的最后一个决定性特性是它知道如何处理动态形状。符号形状是指 Dynamo 跟踪形状(更普遍地说,整数)的能力,而不是将其视为常量。这有助于避免重新编译,并在生产环境中部署适用于任何尺寸的通用模型。出现动态形状的主要例子是批处理大小,我们可能会使用固定批处理大小训练模型,但随后对任意批处理大小执行推理;或者处理文本或音频时遇到的变长序列。
我们可以通过多执行几次上面的示例来看到这一点
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)
在这种情况下,TORCH_LOGS=graph_code
生成另外两个图
# Graph for n==2 omitted
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:7, code: return (n + 1) * y
add = l_n_ + 1
mul = add * y
return (mul,)
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:9, code: return y / n
truediv = y / l_n_
return (truediv,)
Dynamo 检测到第一个调用后一个整数改变了其值,并开始跟踪它。我们看到这些图是通用的,通过类型为 SymInt
的对象符号性地跟踪变量 n
。
如果在这些调用之后,我们调用 fn(x, 4)
,Dynamo 不会重新编译,而是重用已经跟踪的图。
总结一下: 1. Dynamo 是一个 Python 跟踪器 2. 给定一些输入,它返回一个包含已执行 PyTorch 函数的 FX 图 3. 如果检测到整数在调用之间发生了变化,它也可以跟踪整数 4. 它会特殊化除张量或标量之外的任何其他值
当然,Dynamo 还做了更多事情,比如判断何时需要重新跟踪、重写函数的字节码、实现图中断等…… 为了使介绍简短,我们将在后续内容中逐步讨论所有这些。
PEP 523: 为 CPython 添加一个帧评估 API¶
现在想象一下,我们接到了实现 Dynamo 的任务。我们甚至从哪里开始呢?对我们来说相当方便的是,PEP 523 随 Python 3.6 发布。这个 PEP 旨在允许第三方为 Python 创建 JIT 编译器。来看看如何实现。
关于 CPython 的说明:CPython 内部实现为一个栈机。Python 程序被编译成字节码,然后由该解释器执行。要了解更多关于这些字节码的信息,请参阅标准库中的 dis 模块。另请参阅开发者文档,了解 CPython 解释器的介绍。我们假设读者熟悉栈机的概念。
PEP 523 暴露了一个 API,用户可以添加一个自定义的按函数解释器。然后,CPython 将使用此解释器而不是其自己的解释器来执行该函数。为了能够执行函数,在进入时,CPython 会向自定义解释器提供以下信息: - 函数的字节码 - 函数参数的值(即局部变量)及其名称 - 全局变量的值及其名称 - 内置函数,例如 abs
或 print
总之,CPython 为用户的解释器提供了执行函数所需的所有信息。3
有了这个 API,我们可以通过实现一个运行代码并将执行过程中发生的所有 PyTorch 操作记录到图中的解释器来实现跟踪器。这正是 Dynamo 所做的。
Dynamo 使用这个 CPython API 来解析所有这些对象,并将它们打包到一个 Python 结构中。完成这些后……它就从 C 回到 Python 了。除了这部分与 CPython 通信的代码外,Dynamo 完全是用 Python 实现的。
应该清楚的是,装饰器 @torch.compile
的作用是安装必要的支架,以便在函数调用时将字节码、参数、全局变量等传递给 Dynamo。再次强调,@torch.compile
本身实际上不编译任何东西。
在 Python 中实现 CPython¶
所以,我们回到了 Python 世界。我们有了函数的字节码,以及执行它所需的所有上下文。特别是,我们抵达了 _convert_frame_assert。这是装饰器 torch.compile
返回的函数!我们从 _dynamo.optimize 到达此函数。装饰器 torch.compile
只是 _dynamo.optimize
的一个便捷 API。
在开始实现 Python 解释器之前,我们想定义一个 IR(中间表示)。特别是,我们想将所有局部变量和全局变量封装在我们自己的内部类中。这使我们能够更好地跟踪这些对象,并将 Dynamo 看来可以以相同方式处理的对象分组在一起。
内部类结构的父类是 VariableTracker
,它代表 Dynamo 理解的不同对象。例如,ListVariable
代表一个 list
对象,并在内部维护一个 VariableTrackers 列表。另一个 VariableTracker
的例子是 ConstantVariable。ConstantVariable 封装了所有被 Dynamo 视为常量的对象。我们还为需要特别关注的对象设置了特殊的子类,例如 TensorVariable。所有这些内部类都在 torch/_dynamo/variables 文件夹中定义。
Python 对象在 VariableBuilder._wrap 中被封装到其对应的 VariableTracker
类中。此函数只是一个非常长的 elif
链,它尝试将 Python 输入递归地模式匹配到适当的 VariableTracker
类型。
调试技巧。当我们从 dynamo 获得意外结果时,有时是由于构建器引起的。如果构建器的逻辑错误,有时 Dynamo 可能会将变量封装到不正确的 VariableTracker
类型中,这可能导致后续问题。查看错误中出现的 VariableTracker
类型以及遇到 Dynamo 错误时抛出异常的 VariableTracker
方法非常有用。特别是,有时我们会发现一个对象被跟踪为 UserDefinedObjectVariable
(这是 Dynamo 的包罗万象类),而它本应被跟踪为更具体的类型。在这些情况下,通常是 SourceBuilder.__call__
的逻辑问题。
调试技巧。当使用 TORCH_LOGS=dynamo
运行程序时,输出的其中一个信息是以下形式的行
TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(<built-in method any>), TensorVariable()]
这是原始程序的字节码以及当时栈的状态。这对于查找对象未被正确跟踪到 VariableTracker
中的位置非常有用。
好的,我们现在有了跟踪器的 IR,现在我们只需要重新实现 CPython 的栈机。这在 symbolic_convert.py 中的 InstructorTranslatorBase 中实现。
InstructionTranslatorBase
大约有 200 个方法,实现了几乎所有的 Python 字节码。例如,我们可以看看 BUILD_LIST
的实现
def BUILD_LIST(self, inst):
items = self.popn(inst.argval)
self.push(ListVariable(items, mutation_type=ValueMutationNew()))
这是由 l = [2, 3, 4]
这样的结构生成的字节码。在这种情况下,由于有三个元素,生成的字节码是 BUILD_LIST 3
。这意味着我们弹出栈顶的 3
个元素,并将由这三个元素形成的新列表对象压入栈顶。
生成输出图¶
有了符号性执行 Python 代码的方法,我们就可以提取给定输入的程序在符号性执行过程中发生的 PyTorch 操作。这在 Dynamo 中通过 OutputGraph 对象实现。OutputGraph
对象绑定到一个 `InstructionTranslator 对象,它跟踪创建 Dynamo 将返回的 FX 图所需的所有数据。
FX 图的所有输入和中间元素都是 fx.Node
。在 Dynamo 中,fx.Node
被封装在 fx.Proxy
中。fx.Proxy
用于构建 FX 图。特别是,它们会将在其上执行的每个 PyTorch 操作记录到图中。你可以通过调用 create_proxy 创建一个新的操作添加到图中。然后,我们可以通过函数 wrap_fx_proxy 将其添加到图中。
一个图存储张量上的操作……以及符号整数上的操作。我们稍后会讨论符号整数,但首先我们讨论 Dynamo 如何解决一个相当重要的正确性问题。
使 Dynamo 健全:Guards¶
至此,我们有了一种完全忽略控制流来跟踪程序的方法。为此,我们重新实现了整个 CPython……如果这听起来有点过度,那是因为确实如此。torch.jit.trace 已经实现了这一点,而无需所有这些机制,那这是为什么呢?
torch.jit.trace
的问题在于,正如其文档中所警告的,它只适用于跟踪的程序不依赖于数据的情况。换句话说,只有程序本身是线性的时候它才起作用。这意味着编写程序时不能使用 if-else、for-while 循环、异常。更甚者,我们使用的任何库都不能使用任何控制流!总而言之,在一个像 Python 这样动态的语言中不使用控制流,实际上是一个巨大的限制。
JAX 通过始终重新跟踪并在重新跟踪后缓存图来解决这个问题。而 Dynamo 则使用 guards 来避免每次都重新跟踪整个程序。
一个 guard 是为了针对一组示例输入特殊化(specialize)一个帧而做出的假设(关于输入的布尔表达式)。只有当这些假设在新输入上仍然成立时,重用该图才有效。
例如,函数中的任何常量输入,比如字符串,都会安装一个 guard,表明该输入必须是类型 str
且等于我们传递的字符串。运行
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
使用 TORCH_LOGS=guards
会打印出(以及其他 guards)
___check_type_id(L['b'], 94334122025024)
L['b'] == 'Hello'
这可以解读为“局部变量 b
应该具有特定类型(在本例中为 str
,由常量 9433...
表示)且其值应为 'Hello'
”。如果我们随后再次执行该函数并传递不同的参数
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
fn(torch.arange(10), "Hi")
我们可以通过运行 TORCH_LOGS=recompiles
来查看失败的 guard
Recompiling function fn in script.py:3
triggered by the following guard failure(s):
- L['b'] == 'Hello'
Guards 在构建器中封装函数输入和程序执行期间累积。我们将在下一节展示更多 guards 的例子,但首先让我们讨论 sources。
一个 source 跟踪如何从进入当前帧时存在的原始局部变量或全局变量重构一个变量。特别是,它跟踪原始局部对象和全局对象以及它们包含的任何对象。在
def foo(x: Tensor, y: List[Tensor]):
a = x * y[0]
return a * x
x
和 y
的 source 是 LocalSource,而 y[0]
的 source 是 GetItemSource,后者内部存储一个 LocalSource
。另一方面,a
没有 source,因为它是一个只存在于 FX 图中的中间变量。
所有这些都定义在 torch/_dynamo/source.py 中。我们可以在下面的示例中看到 GetItemSource
生成的 guard
import torch
@torch.compile
def fn(x, l):
return x * len(l[0])
fn(torch.randn(8), ["Hi", "Hello"])
生成以下 guards
___check_type_id(L['l'], 94439025877664)
len(L['l']) == 2
___check_type_id(L['l'][0], 94439025840192)
L['l'][0] == 'Hi'
___check_type_id(L['l'][1], 94439025840192)
L['l'][1] == 'Hello'
这里,我们看到 GetItemSource
([0]
和 [1]
) 生成的代码,它封装了一个 LocalSource
(L['l']
)。
至此,有了 sources 和 guards,我们就能够实现一个缓存系统,避免每次都重新跟踪,从而避免重新编译。我们将在后续内容中更详细地讨论这个缓存系统。
细心的读者会注意到,这并没有解释为什么我们需要对 Python 解释器进行如此精细的控制,以至于不得不重新实现它。我们展示的 guards 示例依赖于输入对象,因此我们仍然可以在执行函数之前计算这些 guards。换句话说,我们可以在 torch.jit.trace
的基础上实现这个 guard 系统,并以少得多的精力获得相同的功能…… 这就需要引入符号形状了。
符号形状¶
我们在介绍中讨论的另一点是 Dynamo 知道如何跟踪整数。为了实现这一点,我们使用一个符号类 torch.SymInt,它表现得像一个 int
,但在输出 FX 图中记录了对其执行的所有操作。4 我们在介绍符号整数跟踪时已经在介绍中看到了这个类。
现在让我们讨论定义 Dynamo 中符号形状跟踪的三个属性,以及如何实现它们。
默认静态¶
Dynamo 假定每个整数,无论是输入还是张量的形状,默认都是静态的。换句话说,在函数的第一次执行中,不会追踪任何整数。只有当 Dynamo 检测到整数或形状值在执行过程中发生了变化时,它才会对其进行追踪,并生成一个针对该变量的通用图。
我们已经在介绍中使用整数看到了这种行为。现在让我们看一个使用张量形状的例子。
import torch
@torch.compile
def fn(a, b):
return a.shape[0] * a * b
fn(torch.randn(4, 3), torch.randn(4, 3))
fn(torch.randn(8, 3), torch.randn(8, 3))
使用 TORCH_LOGS=graph_code
运行此程序,我们看到这两个调用被追踪为
def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor):
mul = 4 * l_a_
mul_1 = mul * l_b_
return (mul_1,)
def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor):
size = l_a_.size()
getitem = size[0]
mul = getitem * l_a_
mul_1 = mul * l_b_
return (mul_1,)
在第一个图中,形状被追踪为一个常量,但一旦它发生变化,它就会使用 SymInt
符号化地追踪它。通常,查看中间值形状的更简单方法是使用 TORCH_LOGS=graph_sizes
运行程序
TRACED GRAPH TENSOR SIZES
===== __compiled_fn_1 =====
l_a_: (s0, 3)
l_a_ (concrete): (8, 3)
l_b_: (s0, 3)
l_b_ (concrete): (8, 3)
mul: (s0, 3)
mul (concrete): (8, 3)
mul_1: (s0, 3)
mul_1 (concrete): (8, 3)
在这里我们可以看到,由于它由 s0
变量表示,因此两个张量参数的第一个维度是动态的。
我们可以通过运行 TORCH_LOGS=guards
来了解 Dynamo 如何实现这一点
# Guards first call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
# Guards second call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
L['b'].size()[0] == L['a'].size()[0]
2 <= L['a'].size()[0]
我们看到在第一次调用时,guards 检查张量是否具有固定的尺寸和步长。这些 guards 在第二次执行中失败,因此它会重新追踪。由于失败的是一个 int
guard,因此在第二次迭代中,它会对这个 int
进行符号化追踪,并在更通用的 kernel 上安装更通用的 guards。
编译性能提示。如果你知道某个维度的大小会变化,可以在调用 torch.compile
之前通过调用 torch._dynamo.mark_dynamic 将其标记为动态。这将避免第一次使用静态形状的编译。还有其他有用的实用函数,如 maybe_mark_dynamic
或 mark_static
。你还可以通过调用 torch.compile(dynamic=True)
来追踪所有整数和形状。这主要用于调试目的。
0、1 总是会被特殊化¶
无论我们是否将某个维度标记为动态,如果我们传入一个该维度为 0 或 1 的输入,Dynamo 都会将其追踪为非动态,并为其生成一个特定的图。这就是为什么在上面的例子中我们发现 guards 的形式是 2 <= L['a'].size()[0]
。
做出这个选择有几个原因。其中两个尤其重要:- 当且仅当张量的任一维度为零时,该张量为空。- 当且仅当张量的步长之一为一时,该张量才能是连续的。
此策略决定不适用于普通的 Python int;如果我们认为 Python int 应该动态编译,我们默认不会将其特殊化;相反,它是否被特殊化取决于其用法。
“鸭子”形状 (Duck shaping)¶
Dynamo 执行我们所谓的“鸭子”形状 (duck shaping)。如果在追踪时两个动态整数具有相同的值,我们将假定它们相等并进行守卫 (guard)。实际上,这意味着在上面的示例中,我们不是拥有两个符号 s0
、s1
,而是将它们统一为 s0
并设置守卫 L['b'].size()[0] == L['a'].size()[0]
。这使得能够在编译器内执行融合,同时能够生成足够通用的 kernel。
符号整数上的守卫 (Guards on symbolic ints)¶
我们现在在高层次上理解了符号形状是如何实现的以及它们具有的属性。那么,为什么符号形状迫使我们走上控制 CPython 解释器的棘手道路呢?考虑以下示例:
import torch
@torch.compile(dynamic=True)
def fn(a):
if a.shape[0] * 2 < 16:
return a
else:
return a + 1
fn(torch.randn(8))
此代码有一个形式为 2*L['a'].size()[0] >= 16
的守卫。这是一个在函数输入方面非平凡的守卫,但在程序执行过程中注册。更重要的是,我们直到看到依赖于 SymNodeVariable
参数的 if
语句条件时,才知道需要这个守卫。这些条件对于 torch.jit.trace
是不可见的,需要对 Python 代码进行深入分析。
调试技巧 使用 TORCH_LOGS=dynamo
运行此代码可以告诉我们这个守卫是在哪里添加的
eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr)
在那里设置一个断点并查看回溯对于理解守卫来自何处非常有用。
使 Dynamo 完整:图中断 (Graph Breaks)¶
有了我们讨论过的所有工具,我们现在拥有一个能够追踪张量和整数上的 PyTorch 操作的追踪器,并且它具有一个缓存系统,知道何时可以重用之前追踪的图以及何时需要重新追踪。所有这些都能执行任意 Python 代码!
但这有一个小问题。“执行任意 Python 代码”的说法可能过于宽泛了。Dynamo 实现了 Python 的大部分功能,但它是否实现了更复杂的部分,比如协程 (coroutines) 或异步 (async)?它是否实现了整个 Python 标准库?NumPy 也有 Python API。torch.compile
是否也能理解 NumPy?还有 Django? 5
Python 的生态系统非常庞大,其中很大一部分是用 C++ 或 Rust 等性能更高的语言编写的,并且只暴露了 Python 绑定。Dynamo 无法追踪通过 C++ 实现的 Python 对象。当追踪器遇到它不理解的操作时,它能做什么?
机器学习追踪器处理这个问题通常的方式是告知用户它们在哪个操作上遇到了困难,并完全放弃追踪。这在 PyTorch 中会带来实际的可用性问题,因为 PyTorch 的用户习惯了它提供的灵活性。举一个现实世界的例子,doctr_det_predictor
模型使用了 NumPy 和 cv2
库来对模型结果进行后处理。
这是另一个访问 CPython 很有意义的地方。Dynamo 不会报错,而是可以让 CPython 运行那段有问题代码!为此,Dynamo 在追踪时生成一个包含有问题代码之前所有操作的图,以及一个包含有问题代码之后所有操作的图。6 然后,在运行时,它将委托给 CPython 执行第一个图,然后是有问题的代码,最后是第二个图。停止追踪并生成多个图的过程称为 图中断 (graph break)。
一个小小的坦白:我在整个介绍和前几节中都在撒谎。Dynamo 生成的不是一个图,而是 多个图!实际上,将图中断后重新开始追踪视为开始追踪一个新的函数。图中断后的新图将有自己的 guards、新的局部变量集等等。
要讨论如何实现图中断,我们需要首先回顾 Dynamo 如何与 CPython 交互。使用 PEP 523,CPython 允许用户使用自己的帧评估机制。我们之前没有讨论的是,CPython 也暴露了自己的帧评估供其他人使用。Dynamo 利用这一点,让快速的 CPython 解释器运行编译后的代码。对于一个没有图中断的函数,程序调用该函数两次且参数相同时的整个追踪/执行过程如下所示:
在第一次调用函数时
Dynamo 将函数追踪成一个 FX 图
FX 图由编译器 (Inductor) 编译成高效的底层代码……但这又是另一天的故事了
它重写函数的字节码,使其只需调用编译后的函数
它将这段新的字节码交给 CPython 并要求它运行 [此处]
在第二次调用函数时
这个过程本身看起来过于复杂。为什么生成新的字节码并要求 CPython 运行,而不是简单地创建一个到编译函数的 C++ 绑定并执行它呢?嗯,这种模式使我们能够实现图中断!由图中断生成的字节码具有以下结构:
执行第一个图的字节码
使栈状态与 CPython 执行第一个图后相同的字节码。它还会重放在此刻可见的局部或全局变量的任何修改
导致 Dynamo 图中断的字节码
执行第二个图的字节码
让我们通过一个简单的例子来看看
import torch
@torch.compile
def fn(a):
b = a + 2
print("Hi")
return b + a
fn(torch.randn(4))
使用 TORCH_LOGS=bytecode
运行此程序会显示初始字节码和修改后的字节码
MODIFIED BYTECODE fn script.py line 3
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 CALL_FUNCTION 1
6 STORE_FAST 3 (graph_out_0)
8 LOAD_GLOBAL 0 (print)
10 LOAD_CONST 2 ('Hi')
12 LOAD_FAST 3 (graph_out_0)
14 LOAD_CONST 3 (0)
16 BINARY_SUBSCR
18 STORE_FAST 1 (b)
20 CALL_FUNCTION 1
22 LOAD_GLOBAL 2 (__resume_at_14_1)
24 ROT_TWO
26 LOAD_FAST 0 (a)
28 LOAD_FAST 1 (b)
30 CALL_FUNCTION 3
32 RETURN_VALUE
MODIFIED BYTECODE resume_in_fn script.py line 6
0 LOAD_GLOBAL 1 (__compiled_fn_2)
2 LOAD_FAST 2 (b)
4 LOAD_FAST 1 (a)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 1
10 RETURN_VALUE
我们可以看到修改后的字节码被分割成两个函数:原始函数 fn
,以及一个名为 resume_in_fn
的函数。第二个函数是 Dynamo 为了实现程序从图中断点开始执行而创建的函数。这通常被称为延续函数 (continuation function)。这个延续函数只需使用正确的参数调用第二个编译后的函数。初始函数的代码根据我们之前描述的策略进行了重写
L0-4. 调用编译后的函数 (
a + 2
)。L6. 将其结果存储在一个名为
graph_out_0
的局部变量中。graph_out_0
是一个元组L8-18. 使栈在图中断点保持其应有的状态
L20. 执行导致图中断的代码
L22-32. 调用编译后的延续函数 (
a + b
)
Dynamo 中栈的代码生成被委托给 VariableTracker
子类。Dynamo 中的每个 VariableTracker
对象都有一个 reconstruct 方法,该方法生成必要的字节码以在栈上创建它所代表的 Python 对象。
调试技巧。图中断会影响性能,因此最好避免它们。使用 TORCH_LOGS=graph_breaks
运行程序是找出程序发生了多少次图中断的好方法。它返回的信息是以 VariableTracker
对象的形式呈现的,因此上面的调试技巧有时也有助于弄清楚是什么导致了图中断。
结论¶
Dynamo 是一块复杂的软件。一旦你着手实现 CPython 解释器,你就知道这将是一段不寻常的旅程。话虽如此,我们希望这篇文章能帮助你揭开它的一些神秘面纱。
Dynamo(大部分)是用 Python 实现的。我们留下了许多讨论过的代码片段的链接。我们希望阅读这些代码片段,并搜索调用它们的地方,或在它们上面设置断点并查看调用栈,有助于理解其余的代码库。
当然,学习软件如何工作的最佳方法是扩展它。在这种情况下,最好的方法是查看 github 上的 Dynamo 未解决问题。其中许多只需要对代码进行很小的更改,一旦你找到需要进行这些更改的地方。
脚注¶
- 1
在文献中,这被称为有向无环图 (Directed Acyclical Graph, DAG)。
- 2
所有这些绑定代码都位于
torch/csrc/dynamo/eval_frame.c
中。- 3
在 CPython 术语中,所有这些对象的集合称为一个 frame。
- 4
还有
SymBool
和SymFloat
类。后一个在撰写本文时用得不多。- 5
有趣的是,它确实理解 NumPy 代码!看看这篇博客文章和文档。现在,这之所以可能,是因为我们使用 PyTorch 重新实现了 NumPy。不过,祝你将 Django 用 PyTorch 实现顺利……
- 6
假设只有一段问题代码。如果问题代码更多,Dynamo 可以将代码分割成所需数量的图。