Using the Minifier
We have a pretty convenient test case minifier with this interface:
def minifier(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
       tries replacing quarter of the graph, etc.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    """
    ...
Specifically, it takes your FX graph and tries to minify it with the following 4 strategies (while checking that the resulting graph still returns True for module_fails), until it can't minify it any further.
Truncate suffix: Given an FX graph, it tries to remove some suffix from the graph. For example, given the following:
def f(a):
    b = a * 2
    c = b + 3
    d = c / 4
    return d
It might try truncating the suffix, yielding:
def f(a):
    b = a * 2
    c = b + 3
    return c
It tries to do this in a binary-search fashion: removing the last 1/2, then 3/4, 1/4, then 7/8, 5/8, 3/8, and so on.
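To make the idea concrete, here is a rough sketch of suffix truncation written against the public torch.fx API. This is not the minifier's actual implementation, and the helper name truncate_suffix is made up:
import torch.fx as fx

def truncate_suffix(gm: fx.GraphModule, num_nodes_to_keep: int) -> fx.GraphModule:
    """Copy the first num_nodes_to_keep nodes and make the last one the output.

    Assumes num_nodes_to_keep >= 1 and at least covers the placeholders.
    """
    new_graph = fx.Graph()
    env = {}
    kept = []
    for i, node in enumerate(gm.graph.nodes):
        if node.op == "output" or i >= num_nodes_to_keep:
            break
        # node_copy remaps the copied node's arguments through env.
        env[node] = new_graph.node_copy(node, lambda n: env[n])
        kept.append(env[node])
    new_graph.output(kept[-1])  # the last kept node becomes the new output
    return fx.GraphModule(gm, new_graph)

# e.g. with f from above (nodes: a, b, c, d, output), keeping 3 nodes
# yields a graph that computes b = a * 2; c = b + 3 and returns c:
# print(truncate_suffix(fx.symbolic_trace(f), 3).code)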
Delta Debugging: Of course, removing the suffix isn't always enough to minify the graph. What if the error is caused by the first instruction? So we take an approach inspired by delta debugging - we try removing intermediate nodes of the graph. Unlike with suffixes, the removed nodes still have dependencies, so instead of removing them entirely we promote them to inputs. For example, given the example above:
def f(a):
    b = a * 2
    c = b + 3
    d = c / 4
    return d
We might remove an intermediate node (in this case, c):
def f(a, c):
    b = a * 2
    d = c / 4
    return d
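In code, that "promote a node to an input" step can be sketched roughly like this, again with only public torch.fx calls (lift_to_input is a made-up helper, not functorch's implementation):
import torch.fx as fx

def lift_to_input(gm: fx.GraphModule, target_name: str) -> fx.GraphModule:
    graph = gm.graph
    for node in graph.nodes:
        if node.name == target_name:
            # Placeholders must come first, so insert the new input at the front.
            # (FX may rename it to avoid clashing with the node being removed.)
            with graph.inserting_before(next(iter(graph.nodes))):
                new_input = graph.placeholder(target_name)
            # Everything that consumed the node now reads the new input instead,
            # after which the node itself can be erased.
            node.replace_all_uses_with(new_input)
            graph.erase_node(node)
            break
    gm.recompile()
    return gm

# Note that dead producers (here, b = a * 2) are left behind; the auxiliary
# passes described next clean those up.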
Finally, there are 2 auxiliary strategies - eliminating dead code and removing unused inputs. These are fairly self-explanatory.
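In terms of the public torch.fx API, those two passes amount to roughly the following (cleanup is a made-up name; the real minifier has its own versions):
import torch.fx as fx

def cleanup(gm: fx.GraphModule) -> fx.GraphModule:
    # Dead code elimination: drops nodes whose results are never used,
    # such as the dangling b = a * 2 left behind by lifting c above.
    gm.graph.eliminate_dead_code()
    # Remove unused inputs: placeholders with no remaining users can be erased.
    for node in list(gm.graph.nodes):
        if node.op == "placeholder" and len(node.users) == 0:
            gm.graph.erase_node(node)
    gm.recompile()
    return gm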
So let's take a look at a toy example. Let's pretend that our graph fails if it has a "multiply" in it. Let's create a graph that fails:
import torch
import torch.fx as fx
from functorch.compile import minifier

def failing_f(x, y):
    y = torch.ops.aten.div(x, y)
    x = torch.ops.aten.add(x, 3)
    x = torch.ops.aten.mul(x, y)
    return torch.ops.aten.sub(x, y)

inps = [torch.randn(3), torch.randn(3)]

def pass_checker(fx_g, inps):
    return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))

min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)
[W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key
operator: aten::multiply.Tensor(Tensor self, Tensor other) -> (Tensor)
registered at aten/src/ATen/RegisterSchema.cpp:6
dispatch key: FuncTorchBatched
previous kernel: registered at aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:10338
new kernel: registered at /fsx/users/chilli/work/functorch/functorch/csrc/BatchRulesDecompositions.cpp:108 (function registerKernel)
Started off with 7 nodes
###################
Current size: 7
###################
Strategy: Remove suffix
SUCCESS: Removed [4:7)
###################
Current size: 6
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4
###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 2 inputs
###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
def forward(self, div, add):
    mul = torch.ops.aten.mul(add, div);  add = div = None
    return (mul,)

f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
    for _ in range(5):
        f(*inps)
Tada! Our graph is now a minimal example that still fails.
Since the primary use case of this minifier (currently) is for NVFuser repros, for convenience we also print out a string that creates a self-contained repro to run the minified graph with NVFuser.
Note that in practice, we provide 2 main graph checkers - check_nvfuser_subprocess and check_nvfuser_correctness_subprocess. These are used to check for errors and for correctness (i.e. whether the results match eager), respectively. They can be used like this:
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
minifier(failing_graph, inps, check_nvfuser_subprocess)
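module_fails can also be any callable that takes (fx_graph, inputs) and returns True on failure. As a simplified, in-process sketch of what a correctness-style checker might look like (the real check_nvfuser_correctness_subprocess runs in a subprocess; my_correctness_checker is a made-up name, and this assumes the graph returns a tuple of tensors):
import torch

def my_correctness_checker(fx_g, inps):
    # Reference result in eager mode.
    eager_out = fx_g(*inps)
    # Scripted + fused result (fuser2 selects NVFuser).
    scripted = torch.jit.script(fx_g)
    with torch.jit.fuser("fuser2"):
        for _ in range(3):  # run a few times so fusion actually kicks in
            fused_out = scripted(*inps)
    # module_fails semantics: return True if the graph fails,
    # i.e. the fused results do not match eager.
    return not all(torch.allclose(a, b) for a, b in zip(eager_out, fused_out))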
However, assuming you're using AOTAutograd, there's another problem - how do you obtain the FX graph to pass to the minifier in the first place? One possible way is simply to use print_compile:
from functorch.compile import aot_function
from functorch.compile import print_compile
# Or...
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g

def foo(x):
    return x.cos().cos()

inp = torch.randn(3, requires_grad=True)
aot_function(foo, print_compile)(inp)
def forward(self, primals_1):
    cos = torch.ops.aten.cos(primals_1)
    cos_1 = torch.ops.aten.cos(cos)
    return [cos_1, primals_1, cos]

def forward(self, primals_1, cos, tangents_1):
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(primals_1);  primals_1 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1]
tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)
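One way around the missing-inputs problem (a sketch, not part of functorch's API; saved_graphs and graph_saver are made-up names) is a compiler that stashes both the FX graph and the real inputs AOTAutograd passes to it, so they can be fed straight to the minifier:
from functorch.compile import aot_function

saved_graphs = []  # scratch list, not part of any API

def graph_saver(fx_g, inps):
    # Record the graph together with the example inputs it was compiled with,
    # then return it unchanged so execution proceeds as usual.
    saved_graphs.append((fx_g, inps))
    return fx_g

aot_function(foo, graph_saver)(inp)
# Judging by the print_compile output above, entry 0 is the forward graph
# and entry 1 the backward graph.
fw_graph, fw_inps = saved_graphs[0]
# minifier(fw_graph, fw_inps, pass_checker)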
However, this doesn't provide the inputs, nor does it handle any tensor constants that might be saved in the graph. To resolve that, we have another "compiler" called debug_compile. It simply prints out a string that can be copy-pasted and run from another file. It leverages FX's to_folder feature to serialize the graph to disk, along with any constants. You can apply it to either fw_compiler to dump the forward graph or bw_compiler to dump the backwards graph:
from functorch.compile import memory_efficient_fusion, debug_compile
memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)
##############################################################
# To minimize FX graph, copy and paste the below and run it #
##############################################################
import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()
with torch.jit.fuser("fuser2"):
    # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)
So let's copy-paste it and see how it works - note that I made a couple of small modifications to run on CPU and to use the previous "graphs fail if they contain multiply" checker.
import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule()
minifier(fx.symbolic_trace(mod), inps, pass_checker)
Started off with 10 nodes
###################
Current size: 10
###################
Strategy: Remove suffix
SUCCESS: Removed [6:10)
###################
Current size: 8
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4
###################
Current size: 8
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 3 inputs
###################
Current size: 7
###################
Strategy: Remove suffix
SUCCESS: Removed [4:7)
###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs
###################
Current size: 5
###################
Strategy: Delta Debugging
SUCCESS: Removed (2:3] - Went from 2 placeholders to 3
###################
Current size: 5
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs
###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix
inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
def forward(self, tangents_1, neg):
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    return (mul,)

f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
    for _ in range(5):
        f(*inps)
(GraphModule(), [tensor([1., 1., 1.]), tensor([-0.5144, -0.5144, -0.5144])])
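As in the first toy example, the call returns the minified GraphModule together with matching inputs, so (as a quick sketch) the repro can be inspected and re-run directly:
min_mod, min_inps = minifier(fx.symbolic_trace(mod), inps, pass_checker)
print(min_mod.code)   # the minified graph
min_mod(*min_inps)    # re-run it with the minified inputs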
Hopefully that was useful :)