AOT 自动微分 - 如何使用和优化?¶
背景¶
在本教程中,我们将学习如何使用 AOT 自动微分来加速深度学习模型的训练。
作为背景,AOT 自动微分是一个工具包,可帮助开发人员加速 PyTorch 上的训练。总体而言,它具有两个关键特性
AOT 自动微分提前跟踪前向和后向图。提前存在前向和后向图有助于进行联合图优化,例如重新计算或激活检查点。
AOT 自动微分提供了简单的机制,可以通过深度学习编译器(例如 NVFuser、NNC、TVM 等)编译提取的前向和后向图。
你将学到什么?¶
在本教程中,我们将看看如何将 AOT 自动微分与后端编译器结合使用,以加速 PyTorch 模型的训练。更具体地说,你将学习
如何使用 AOT 自动微分?
AOT 自动微分如何使用后端编译器执行操作融合?
AOT 自动微分如何启用训练特定的优化,例如重新计算?
所以,让我们开始吧。
设置¶
让我们设置一个简单的模型。
import torch
def fn(a, b, c, d):
x = a + b + c + d
return x.cos().cos()
# Test that it works
a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]
ref = fn(a, b, c, d)
loss = ref.sum()
loss.backward()
使用 AOT 自动微分¶
现在,让我们使用 AOT 自动微分并查看提取的前向和后向图。在内部,AOT 使用 __torch_dispatch__
基于跟踪的机制来提取前向和后向图,并将它们包装在 torch.Fx
GraphModule 容器中。请注意,AOT 自动微分跟踪与通常的 Fx 符号跟踪不同。AOT 自动微分使用 Fx GraphModule 只是为了表示跟踪的图(而不是用于跟踪)。
然后,AOT 自动微分将这些前向和后向图发送到用户提供的编译器。所以,让我们编写一个编译器,它只打印图。
from functorch.compile import aot_function
# The compiler_fn is called after the forward and backward graphs are extracted.
# Here, we just print the code in the compiler_fn. Return of this function is a callable.
def compiler_fn(fx_module: torch.fx.GraphModule, _):
print(fx_module.code)
return fx_module
# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
# Run the aot_print_fn once to trigger the compilation and print the graphs
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
res.sum().backward()
assert torch.allclose(ref, res)
def forward(self, primals_1, primals_2, primals_3, primals_4):
add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None
add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None
add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None
cos = torch.ops.aten.cos(add_2)
cos_1 = torch.ops.aten.cos(cos)
return [cos_1, add_2, cos]
def forward(self, add_2, cos, tangents_1):
sin = torch.ops.aten.sin(cos); cos = None
neg = torch.ops.aten.neg(sin); sin = None
mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None
sin_1 = torch.ops.aten.sin(add_2); add_2 = None
neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None
mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None
return [mul_1, mul_1, mul_1, mul_1]
上面的代码打印了前向和后向图的 Fx 图。你可以看到,除了前向传递的原始输入之外,前向图还输出了一些额外的张量。这些张量被保存用于后向传递以进行梯度计算。在讨论重新计算时,我们将回到它们。
操作融合¶
现在我们了解了如何使用 AOT 自动微分来打印前向和后向图,让我们使用 AOT 自动微分来使用一些实际的深度学习编译器。在本教程中,我们使用 PyTorch 神经网络编译器 (NNC) 为 CPU 设备执行逐点操作融合。对于 CUDA 设备,合适的替代方案是 NvFuser。所以,让我们使用 NNC
# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile
from functorch.compile import ts_compile
# Lets compile the forward and backward through ts_compile.
aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)
让我们对原始函数和 AOT 自动微分 + NNC 编译的函数进行基准测试。
# Lets write a function to benchmark the forward and backward pass
import time
import statistics
def bench(fn, args, prefix):
warmup = 10
iterations = 100
for _ in range(warmup):
ref = fn(*args)
ref.sum().backward()
fw_latencies = []
bw_latencies = []
for _ in range(iterations):
for arg in args:
arg.grad = None
fw_begin = time.perf_counter()
ref = fn(*args)
fw_end = time.perf_counter()
loss = ref.sum()
bw_begin = time.perf_counter()
loss.backward()
bw_end = time.perf_counter()
fw_latencies.append(fw_end - fw_begin)
bw_latencies.append(bw_end - bw_begin)
avg_fw_latency = statistics.mean(fw_latencies) * 10**6
avg_bw_latency = statistics.mean(bw_latencies) * 10**6
print(prefix, "Fwd = " + str(avg_fw_latency) + " us", "Bwd = " + str(avg_bw_latency) + " us", sep=', ')
large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]
# Benchmark the Eager and AOT Autograd functions
bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us
AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us
在 NNC 的帮助下,AOT 自动微分加速了前向和后向传递。如果我们查看之前打印的图,所有操作都是逐点的。逐点操作是内存带宽绑定的,因此受益于操作融合。仔细查看这些数字,后向传递获得了更高的加速。这是因为前向传递必须为后向传递的梯度计算输出一些中间张量,从而阻止它保存一些内存读写操作。但是,后向图中不存在这种限制。
重新计算(也称为激活检查点)¶
重新计算(通常称为激活检查点)是一种技术,其中,我们不是将一些激活保存起来在反向传播中使用,而是在反向传播过程中重新计算它们。重新计算节省了内存,但我们付出了性能开销。
但是,在存在融合编译器的情况下,我们可以做得更好。我们可以重新计算融合友好的操作以节省内存,然后依靠融合编译器来融合重新计算的操作。这减少了内存和运行时间。请参考 讨论帖 了解更多详细信息。
在这里,我们使用 AOT 自动微分与 NNC 一起执行类似类型的重新计算。在 __torch_dispatch__
跟踪结束时,AOT 自动微分有一个前向图和一个联合前向-后向图。然后,AOT 自动微分使用一个分区器来隔离前向和后向图。在上面的示例中,我们使用了一个默认分区器。对于这个实验,我们将使用另一个名为 min_cut_rematerialization_partition
的分区器来执行更智能的融合感知重新计算。分区器是可配置的,可以编写自己的分区器以将其插入 AOT 自动微分。
from functorch.compile import min_cut_rematerialization_partition
# Zero out the gradients so we can do a comparison later
a.grad, b.grad, c.grad, d.grad = (None,) * 4
# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.
# This will show us how the recomputation has modified the graph.
aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)
res = aot_fn(a, b, c, d).sum().backward()
def forward(self, primals_1, primals_2, primals_3, primals_4):
add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None
add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None
add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None
cos = torch.ops.aten.cos(add_2)
cos_1 = torch.ops.aten.cos(cos); cos = None
return [cos_1, add_2]
def forward(self, add_2, tangents_1):
cos = torch.ops.aten.cos(add_2)
sin = torch.ops.aten.sin(cos); cos = None
neg = torch.ops.aten.neg(sin); sin = None
mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None
sin_1 = torch.ops.aten.sin(add_2); add_2 = None
neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None
mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None
return [mul_1, mul_1, mul_1, mul_1]
我们可以看到,与默认分区器相比,前向传递现在输出的张量更少,并在后向传递中重新计算了一些操作。现在让我们尝试使用 NNC 编译器来执行操作融合(注意,我们还有一个包装函数 - memory_efficient_fusion
,它在内部使用 min_cut_rematerialization_partition
和 Torchscript 编译器来实现与以下代码相同的效果)。
# Lets set up the partitioner and NNC compiler.
aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)
# Correctness checking. Lets clone the input so that we can check grads.
cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
res = aot_recompute_nnc_fn(*cloned_inputs)
loss = res.sum()
loss.backward()
assert torch.allclose(ref, res)
assert torch.allclose(a.grad, cloned_a.grad)
assert torch.allclose(b.grad, cloned_b.grad)
assert torch.allclose(c.grad, cloned_c.grad)
assert torch.allclose(d.grad, cloned_d.grad)
最后,让我们对不同的函数进行基准测试
bench(fn, large_inputs, "Eager")
bench(aot_nnc_fn, large_inputs, "AOT")
bench(aot_recompute_nnc_fn, large_inputs, "AOT_Recomp")
Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us
AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us
AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us
我们观察到,前向和后向延迟都比默认分区器(以及比 eager 好得多)有所提高。前向传递中较少的输出和后向传递中较少的输入,以及融合,允许更好的内存带宽利用率,从而导致进一步的加速。
实际使用¶
对于 CUDA 设备上的实际使用,我们已将 AOTAutograd 包装在一个方便的包装器中 - memory_efficient_fusion
。在 GPU 上使用它进行融合!
from functorch.compile import memory_efficient_fusion