• 文档 >
  • AOT 自动微分 - 如何使用和优化?
快捷方式

AOT 自动微分 - 如何使用和优化?

Open In Colab

背景

在本教程中,我们将学习如何使用 AOT 自动微分来加速深度学习模型的训练。

作为背景,AOT 自动微分是一个工具包,可帮助开发人员加速 PyTorch 上的训练。总体而言,它具有两个关键特性

  • AOT 自动微分提前跟踪前向和后向图。提前存在前向和后向图有助于进行联合图优化,例如重新计算或激活检查点。

  • AOT 自动微分提供了简单的机制,可以通过深度学习编译器(例如 NVFuser、NNC、TVM 等)编译提取的前向和后向图。

你将学到什么?

在本教程中,我们将看看如何将 AOT 自动微分与后端编译器结合使用,以加速 PyTorch 模型的训练。更具体地说,你将学习

  • 如何使用 AOT 自动微分?

  • AOT 自动微分如何使用后端编译器执行操作融合?

  • AOT 自动微分如何启用训练特定的优化,例如重新计算?

所以,让我们开始吧。

设置

让我们设置一个简单的模型。

import torch

def fn(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()
# Test that it works
a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)
loss = ref.sum()
loss.backward()

使用 AOT 自动微分

现在,让我们使用 AOT 自动微分并查看提取的前向和后向图。在内部,AOT 使用 __torch_dispatch__ 基于跟踪的机制来提取前向和后向图,并将它们包装在 torch.Fx GraphModule 容器中。请注意,AOT 自动微分跟踪与通常的 Fx 符号跟踪不同。AOT 自动微分使用 Fx GraphModule 只是为了表示跟踪的图(而不是用于跟踪)。

然后,AOT 自动微分将这些前向和后向图发送到用户提供的编译器。所以,让我们编写一个编译器,它只打印图。

from functorch.compile import aot_function

# The compiler_fn is called after the forward and backward graphs are extracted.
# Here, we just print the code in the compiler_fn. Return of this function is a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
    print(fx_module.code)
    return fx_module

# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

# Run the aot_print_fn once to trigger the compilation and print the graphs
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()
assert torch.allclose(ref, res)
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos(add_2)
    cos_1 = torch.ops.aten.cos(cos)
    return [cos_1, add_2, cos]
    



def forward(self, add_2, cos, tangents_1):
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]
    

上面的代码打印了前向和后向图的 Fx 图。你可以看到,除了前向传递的原始输入之外,前向图还输出了一些额外的张量。这些张量被保存用于后向传递以进行梯度计算。在讨论重新计算时,我们将回到它们。

操作融合

现在我们了解了如何使用 AOT 自动微分来打印前向和后向图,让我们使用 AOT 自动微分来使用一些实际的深度学习编译器。在本教程中,我们使用 PyTorch 神经网络编译器 (NNC) 为 CPU 设备执行逐点操作融合。对于 CUDA 设备,合适的替代方案是 NvFuser。所以,让我们使用 NNC

# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile
from functorch.compile import ts_compile

# Lets compile the forward and backward through ts_compile.
aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)

# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs

res = aot_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)

让我们对原始函数和 AOT 自动微分 + NNC 编译的函数进行基准测试。

# Lets write a function to benchmark the forward and backward pass
import time
import statistics

def bench(fn, args, prefix):
    warmup = 10
    iterations = 100

    for _ in range(warmup):
        ref = fn(*args)
        ref.sum().backward()
    
    fw_latencies = []
    bw_latencies = []
    for _ in range(iterations):
        for arg in args:
            arg.grad = None

        fw_begin = time.perf_counter()
        ref = fn(*args)
        fw_end = time.perf_counter()

        loss = ref.sum() 

        bw_begin = time.perf_counter()
        loss.backward()
        bw_end = time.perf_counter()

        fw_latencies.append(fw_end - fw_begin)
        bw_latencies.append(bw_end - bw_begin)
    
    avg_fw_latency = statistics.mean(fw_latencies) * 10**6
    avg_bw_latency = statistics.mean(bw_latencies) * 10**6
    print(prefix, "Fwd = " + str(avg_fw_latency) + " us", "Bwd = " + str(avg_bw_latency) + " us", sep=', ')
large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]

# Benchmark the Eager and AOT Autograd functions
bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us
AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us

在 NNC 的帮助下,AOT 自动微分加速了前向和后向传递。如果我们查看之前打印的图,所有操作都是逐点的。逐点操作是内存带宽绑定的,因此受益于操作融合。仔细查看这些数字,后向传递获得了更高的加速。这是因为前向传递必须为后向传递的梯度计算输出一些中间张量,从而阻止它保存一些内存读写操作。但是,后向图中不存在这种限制。

重新计算(也称为激活检查点)

重新计算(通常称为激活检查点)是一种技术,其中,我们不是将一些激活保存起来在反向传播中使用,而是在反向传播过程中重新计算它们。重新计算节省了内存,但我们付出了性能开销。

但是,在存在融合编译器的情况下,我们可以做得更好。我们可以重新计算融合友好的操作以节省内存,然后依靠融合编译器来融合重新计算的操作。这减少了内存和运行时间。请参考 讨论帖 了解更多详细信息。

在这里,我们使用 AOT 自动微分与 NNC 一起执行类似类型的重新计算。在 __torch_dispatch__ 跟踪结束时,AOT 自动微分有一个前向图和一个联合前向-后向图。然后,AOT 自动微分使用一个分区器来隔离前向和后向图。在上面的示例中,我们使用了一个默认分区器。对于这个实验,我们将使用另一个名为 min_cut_rematerialization_partition 的分区器来执行更智能的融合感知重新计算。分区器是可配置的,可以编写自己的分区器以将其插入 AOT 自动微分。

from functorch.compile import min_cut_rematerialization_partition

# Zero out the gradients so we can do a comparison later
a.grad, b.grad, c.grad, d.grad = (None,) * 4

# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.
# This will show us how the recomputation has modified the graph.
aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)
res = aot_fn(a, b, c, d).sum().backward()
def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add(add, primals_3);  add = primals_3 = None
    add_2 = torch.ops.aten.add(add_1, primals_4);  add_1 = primals_4 = None
    cos = torch.ops.aten.cos(add_2)
    cos_1 = torch.ops.aten.cos(cos);  cos = None
    return [cos_1, add_2]
    



def forward(self, add_2, tangents_1):
    cos = torch.ops.aten.cos(add_2)
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(add_2);  add_2 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1, mul_1, mul_1, mul_1]
    

我们可以看到,与默认分区器相比,前向传递现在输出的张量更少,并在后向传递中重新计算了一些操作。现在让我们尝试使用 NNC 编译器来执行操作融合(注意,我们还有一个包装函数 - memory_efficient_fusion,它在内部使用 min_cut_rematerialization_partition 和 Torchscript 编译器来实现与以下代码相同的效果)。

# Lets set up the partitioner and NNC compiler.
aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)

# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs

res = aot_recompute_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)

最后,让我们对不同的函数进行基准测试

bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
bench(aot_recompute_nnc_fn, large_inputs, "AOT_Recomp")
Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us
AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us
AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us

我们观察到,前向和后向延迟都比默认分区器(以及比 eager 好得多)有所提高。前向传递中较少的输出和后向传递中较少的输入,以及融合,允许更好的内存带宽利用率,从而导致进一步的加速。

实际使用

对于 CUDA 设备上的实际使用,我们已将 AOTAutograd 包装在一个方便的包装器中 - memory_efficient_fusion。在 GPU 上使用它进行融合!

from functorch.compile import memory_efficient_fusion

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源