• 教程 >
  • 编译自动微分:为 torch.compile 捕获更大的反向图
快捷方式

编译自动微分:为 torch.compile 捕获更大的反向图

作者: Simon Fan

您将学到什么
  • 编译自动微分如何与 torch.compile 交互

  • 如何使用编译自动微分 API

  • 如何使用 TORCH_LOGS 检查日志

先决条件

概述

编译自动微分是 PyTorch 2.4 中引入的 torch.compile 扩展,它允许捕获更大的反向图。

虽然 torch.compile 会捕获反向图,但它会部分捕获。AOTAutograd 组件会在提前时间捕获反向图,但存在某些限制

  • 前向图中的图断裂会导致反向图中的图断裂

  • 反向钩子 不会被捕获

编译自动微分通过直接与自动微分引擎集成来解决这些限制,从而使其能够在运行时捕获完整的反向图。具有这两个特征的模型应该尝试编译自动微分,并且可能会观察到更好的性能。

但是,编译自动微分也引入了自己的限制

  • 在反向传播开始时增加了运行时开销以进行缓存查找

  • 由于捕获范围更大,在 Dynamo 中更容易出现重新编译和图形中断

注意

编译后的 Autograd 正在积极开发中,尚未与所有现有的 PyTorch 功能兼容。有关特定功能的最新状态,请参阅 编译后的 Autograd 着陆页.

设置

在本教程中,我们将以这个简单的神经网络模型作为示例。它接收一个 10 维的输入向量,通过一个线性层进行处理,并输出另一个 10 维的向量。

import torch

class Model(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(10, 10)

   def forward(self, x):
      return self.linear(x)

基本用法

在调用 torch.compile API 之前,请确保将 torch._dynamo.config.compiled_autograd 设置为 True

model = Model()
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
   loss = model(x).sum()
   loss.backward()

train(model, x)

在上面的代码中,我们创建了 Model 类的实例,并使用 torch.randn(10) 生成了一个随机的 10 维张量 x。我们定义了训练循环函数 train,并使用 @torch.compile 对其进行装饰以优化其执行。当调用 train(model, x)

  • Python 解释器调用 Dynamo,因为此调用已使用 @torch.compile 进行装饰。

  • Dynamo 拦截 Python 字节码,模拟其执行,并将操作记录到图形中。

  • AOTDispatcher 禁用钩子并调用 autograd 引擎为 model.linear.weightmodel.linear.bias 计算梯度,并将操作记录到图形中。使用 torch.autograd.Function,AOTDispatcher 重写了 train 的前向和后向实现。

  • Inductor 生成一个对应于 AOTDispatcher 前向和后向的优化实现的函数。

  • Dynamo 将优化后的函数设置为 Python 解释器下一步要执行的函数。

  • Python 解释器执行优化后的函数,该函数执行 loss = model(x).sum()

  • Python 解释器执行 loss.backward(),调用 autograd 引擎,该引擎路由到编译后的 Autograd 引擎,因为我们已将 torch._dynamo.config.compiled_autograd = True

  • 编译后的 Autograd 为 model.linear.weightmodel.linear.bias 计算梯度,并将操作记录到图形中,包括它遇到的任何钩子。在此过程中,它将记录先前由 AOTDispatcher 重写的后向。然后,编译后的 Autograd 生成一个新的函数,该函数对应于 loss.backward() 的完全跟踪实现,并使用 torch.compile 在推理模式下执行它。

  • 相同的步骤递归地应用于编译后的 Autograd 图形,但这次 AOTDispatcher 将不需要对图形进行分区。

检查编译后的 Autograd 日志

使用 TORCH_LOGS 环境变量运行脚本

  • 要仅打印编译后的 Autograd 图形,请使用 TORCH_LOGS="compiled_autograd" python example.py

  • 要以性能为代价打印具有更多张量元数据和重新编译原因的图形,请使用 TORCH_LOGS="compiled_autograd_verbose" python example.py

重新运行上面的代码段,编译后的 Autograd 图形现在应该被记录到 stderr 中。某些图形节点将具有以 aot0_ 为前缀的名称,这些名称对应于先前在 AOTAutograd 后向图形 0 中提前编译的节点,例如,aot0_view_2 对应于 ID 为 0 的 AOT 后向图形的 view_2

stderr_output = """
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
===== Compiled autograd graph =====
<eval_with_key>.4 class CompiledAutograd(torch.nn.Module):
   def forward(self, inputs, sizes, scalars, hooks):
      # No stacktrace found for following nodes
      aot0_tangents_1: "f32[][]cpu" = inputs[0]
      aot0_primals_3: "f32[10][1]cpu" = inputs[1]
      getitem_2: "f32[10][1]cpu" = inputs[2]
      getitem_3: "f32[10, 10][10, 1]cpu" = inputs[3];  inputs = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
      aot0_expand: "f32[10][0]cpu" = torch.ops.aten.expand.default(aot0_tangents_1, [10]);  aot0_tangents_1 = None
      aot0_view_2: "f32[1, 10][0, 0]cpu" = torch.ops.aten.view.default(aot0_expand, [1, 10]);  aot0_expand = None
      aot0_permute_2: "f32[10, 1][0, 0]cpu" = torch.ops.aten.permute.default(aot0_view_2, [1, 0])
      aot0_select: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 0)
      aot0_view: "f32[1, 10][10, 1]cpu" = torch.ops.aten.view.default(aot0_primals_3, [1, 10]);  aot0_primals_3 = None
      aot0_mul_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select, aot0_view);  aot0_select = None
      aot0_select_1: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 1)
      aot0_mul_4: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_1, aot0_view);  aot0_select_1 = None
      aot0_select_2: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 2)
      aot0_mul_5: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_2, aot0_view);  aot0_select_2 = None
      aot0_select_3: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 3)
      aot0_mul_6: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_3, aot0_view);  aot0_select_3 = None
      aot0_select_4: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 4)
      aot0_mul_7: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_4, aot0_view);  aot0_select_4 = None
      aot0_select_5: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 5)
      aot0_mul_8: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_5, aot0_view);  aot0_select_5 = None
      aot0_select_6: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 6)
      aot0_mul_9: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_6, aot0_view);  aot0_select_6 = None
      aot0_select_7: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 7)
      aot0_mul_10: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_7, aot0_view);  aot0_select_7 = None
      aot0_select_8: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 8)
      aot0_mul_11: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_8, aot0_view);  aot0_select_8 = None
      aot0_select_9: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 9);  aot0_permute_2 = None
      aot0_mul_12: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_9, aot0_view);  aot0_select_9 = aot0_view = None
      aot0_cat: "f32[10, 10][10, 1]cpu" = torch.ops.aten.cat.default([aot0_mul_3, aot0_mul_4, aot0_mul_5, aot0_mul_6, aot0_mul_7, aot0_mul_8, aot0_mul_9, aot0_mul_10, aot0_mul_11, aot0_mul_12]);  aot0_mul_3 = aot0_mul_4 = aot0_mul_5 = aot0_mul_6 = aot0_mul_7 = aot0_mul_8 = aot0_mul_9 = aot0_mul_10 = aot0_mul_11 = aot0_mul_12 = None
      aot0_permute_3: "f32[10, 10][1, 10]cpu" = torch.ops.aten.permute.default(aot0_cat, [1, 0]);  aot0_cat = None
      aot0_sum_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.sum.dim_IntList(aot0_view_2, [0], True);  aot0_view_2 = None
      aot0_view_3: "f32[10][1]cpu" = torch.ops.aten.view.default(aot0_sum_3, [10]);  aot0_sum_3 = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 2)
      accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, aot0_view_3);  getitem_2 = aot0_view_3 = accumulate_grad_ = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
      aot0_permute_4: "f32[10, 10][10, 1]cpu" = torch.ops.aten.permute.default(aot0_permute_3, [1, 0]);  aot0_permute_3 = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3)
      accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, aot0_permute_4);  getitem_3 = aot0_permute_4 = accumulate_grad__1 = None
      _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub();  _exec_final_callbacks_stub = None
      return []
"""

注意

这是我们将对它调用 torch.compile 的图形,**不是**优化后的图形。编译后的 Autograd 本质上生成了一些未优化的 Python 代码来表示整个 C++ autograd 执行。

使用不同的标志编译前向和后向传递

您可以对两次编译使用不同的编译器配置,例如,即使前向存在图形中断,后向也可能是一个完整图形。

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

或者,您可以使用上下文管理器,它将应用于其范围内的所有 autograd 调用。

def train(model, x):
   model = torch.compile(model)
   loss = model(x).sum()
   with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
      loss.backward()

编译后的 Autograd 解决了 AOTAutograd 的一些限制

  1. 前向传递中的图形中断会导致后向传递中的图形中断

@torch.compile(backend="aot_eager")
def fn(x):
   # 1st graph
   temp = x + 10
   torch._dynamo.graph_break()
   # 2nd graph
   temp = temp + 10
   torch._dynamo.graph_break()
   # 3rd graph
   return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)

# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()

# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

在第一个 torch.compile 案例中,我们看到由于编译后的函数 fn 中的 2 个图形中断而产生了 3 个后向图形。而在第二个使用编译后的 autograd 的 torch.compile 案例中,我们看到尽管存在图形中断,但仍然跟踪了一个完整的后向图形。

  1. 没有捕获后向钩子

@torch.compile(backend="aot_eager")
def fn(x):
   return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)

with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

图形中应该有一个 call_hook 节点,Dynamo 稍后会将其内联到以下内容中

stderr_output = """
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
===== Compiled autograd graph =====
<eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
   def forward(self, inputs, sizes, scalars, hooks):
      ...
      getitem_2 = hooks[0];  hooks = None
      call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook');  getitem_2 = aot0_expand = None
      ...
"""

编译后的 Autograd 的常见重新编译原因

  1. 由于损失值的 autograd 结构发生了变化

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
   loss = op(x, x).sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的示例中,我们在每次迭代中调用不同的运算符,导致 loss 每次跟踪不同的 autograd 历史记录。您应该看到一些重新编译消息:**由于新的 autograd 节点导致缓存未命中**。

stderr_output = """
Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
...
Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
...
Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
...
Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
...
"""
  1. 由于张量形状发生了变化

torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
   x = torch.randn(i, i, requires_grad=True)
   loss = x.sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的示例中,x 形状发生了变化,编译后的 autograd 将在第一次更改后将 x 标记为动态形状张量。您应该看到重新编译消息:**由于形状发生变化导致缓存未命中**。

stderr_output = """
...
Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
...
"""

结论

在本教程中,我们介绍了 torch.compile 与编译后的 autograd 的高级生态系统,编译后的 autograd 的基础知识以及一些常见的重新编译原因。请继续关注 dev-discuss 上的深入探讨。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源