Compiled Autograd:为 torch.compile
捕获更大的反向图¶
创建于:2024 年 10 月 09 日 | 最后更新:2024 年 10 月 23 日 | 最后验证:2024 年 10 月 09 日
作者: Simon Fan
Compiled Autograd 如何与
torch.compile
交互如何使用 Compiled Autograd API
如何使用
TORCH_LOGS
检查日志
PyTorch 2.4
通读 PyTorch 2.x 入门 的 TorchDynamo 和 AOTAutograd 部分
概述¶
Compiled Autograd 是 PyTorch 2.4 中引入的 torch.compile
扩展,允许捕获更大的反向图。
虽然 torch.compile
确实捕获了反向图,但它是部分捕获的。 AOTAutograd 组件提前捕获反向图,但存在一些限制
前向图中的断点会导致反向图中的断点
反向钩子未被捕获
Compiled Autograd 通过直接与 autograd 引擎集成来解决这些限制,使其能够在运行时捕获完整的反向图。 具有这两个特征的模型应尝试 Compiled Autograd,并可能观察到更好的性能。
但是,Compiled Autograd 引入了自身的限制
在反向传播开始时添加了运行时开销以进行缓存查找
由于更大的捕获范围,更容易在 dynamo 中重新编译和图断裂
注意
Compiled Autograd 正在积极开发中,尚未与所有现有的 PyTorch 功能兼容。 有关特定功能的最新状态,请参阅Compiled Autograd 登陆页面。
设置¶
在本教程中,我们将以这个简单的神经网络模型为例。 它接受一个 10 维输入向量,通过单个线性层处理,并输出另一个 10 维向量。
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
基本用法¶
在调用 torch.compile
API 之前,请确保将 torch._dynamo.config.compiled_autograd
设置为 True
model = Model()
x = torch.randn(10)
torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
loss = model(x).sum()
loss.backward()
train(model, x)
在上面的代码中,我们创建了一个 Model
类的实例,并使用 torch.randn(10)
生成一个随机的 10 维张量 x
。 我们定义了训练循环函数 train
并使用 @torch.compile 对其进行装饰以优化其执行。 当调用 train(model, x)
时
Python 解释器调用 Dynamo,因为此调用已使用
@torch.compile
进行装饰。Dynamo 拦截 Python 字节码,模拟其执行并将操作记录到图中。
AOTDispatcher
禁用钩子并调用 autograd 引擎来计算model.linear.weight
和model.linear.bias
的梯度,并将操作记录到图中。 使用torch.autograd.Function
,AOTDispatcher 重写train
的前向和后向实现。Inductor 生成一个函数,该函数对应于 AOTDispatcher 前向和后向的优化实现。
Dynamo 设置优化的函数,以便 Python 解释器接下来进行评估。
Python 解释器执行优化的函数,该函数执行
loss = model(x).sum()
。Python 解释器执行
loss.backward()
,调用 autograd 引擎,由于我们设置了torch._dynamo.config.compiled_autograd = True
,因此路由到 Compiled Autograd 引擎。Compiled Autograd 计算
model.linear.weight
和model.linear.bias
的梯度,并将操作记录到图中,包括它遇到的任何钩子。 在此过程中,它将记录先前由 AOTDispatcher 重写的反向传播。 然后,Compiled Autograd 生成一个新函数,该函数对应于loss.backward()
的完全跟踪实现,并在推理模式下使用torch.compile
执行它。相同的步骤递归地应用于 Compiled Autograd 图,但是这次 AOTDispatcher 将不需要对图进行分区。
检查 Compiled Autograd 日志¶
使用 TORCH_LOGS
环境变量运行脚本
要仅打印 Compiled Autograd 图,请使用
TORCH_LOGS="compiled_autograd" python example.py
要打印具有更多张量元数据和重新编译原因的图,以性能为代价,请使用
TORCH_LOGS="compiled_autograd_verbose" python example.py
重新运行上面的代码片段,Compiled Autograd 图现在应该记录到 stderr
。 某些图节点的名称将以 aot0_
为前缀,这些节点对应于先前在 AOTAutograd 反向图 0 中提前编译的节点,例如,aot0_view_2
对应于 id=0 的 AOT 反向图的 view_2
。
在下图中,红色框封装了在没有 Compiled Autograd 的情况下由 torch.compile
捕获的 AOT 反向图。

注意
这是我们将调用 torch.compile
的图,不是优化的图。 Compiled Autograd 本质上生成一些未优化的 Python 代码来表示整个 C++ autograd 执行。
使用不同的标志编译前向和后向传递¶
您可以为两次编译使用不同的编译器配置,例如,即使前向传播中存在图断裂,后向传播也可能是全图。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
torch._dynamo.config.compiled_autograd = True
torch.compile(lambda: loss.backward(), fullgraph=True)()
或者,您可以使用上下文管理器,它将应用于其范围内的所有 autograd 调用。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
loss.backward()
Compiled Autograd 解决了 AOTAutograd 的某些限制¶
前向传播中的图断裂不再一定导致后向传播中的图断裂
@torch.compile(backend="aot_eager")
def fn(x):
# 1st graph
temp = x + 10
torch._dynamo.graph_break()
# 2nd graph
temp = temp + 10
torch._dynamo.graph_break()
# 3rd graph
return temp.sum()
x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)
# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()
# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)
在第一个 torch.compile
案例中,我们看到由于编译函数 fn
中的 2 个图断裂,生成了 3 个反向图。 而在第二个带有 compiled autograd 的 torch.compile
案例中,我们看到尽管存在图断裂,但仍跟踪了完整的反向图。
注意
当跟踪 Compiled Autograd 捕获的反向钩子时,Dynamo 仍然可能发生图断裂。
现在可以捕获反向钩子
@torch.compile(backend="aot_eager")
def fn(x):
return x.sum()
x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
图中应该有一个 call_hook
节点,dynamo 稍后会将其内联到以下内容中

Compiled Autograd 的常见重新编译原因¶
由于损失值的 autograd 结构发生变化
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
loss = op(x, x).sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,我们在每次迭代时调用不同的运算符,导致 loss
每次都跟踪不同的 autograd 历史记录。 您应该看到一些重新编译消息:由于新的 autograd 节点而导致的缓存未命中。

由于张量改变形状
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
x = torch.randn(i, i, requires_grad=True)
loss = x.sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,x
更改形状,并且 compiled autograd 将在第一次更改后将 x
标记为动态形状张量。 您应该看到重新编译消息:由于形状更改而导致的缓存未命中。

结论¶
在本教程中,我们介绍了带有 compiled autograd 的 torch.compile
的高层生态系统、compiled autograd 的基础知识以及一些常见的重新编译原因。 请继续关注 dev-discuss 上的深入探讨。