快捷方式

使用 Minifier

我们有一个非常方便的测试用例 minifier,它具有以下接口

def minifier(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
        tries replacing quarter of the graph, etc.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    ...

具体来说,它接收你的 FX 图,并尝试使用以下 4 种策略对其进行压缩(同时检查生成的图是否仍然返回 True 用于 module_fails),直到无法再压缩为止。

  1. 截断后缀:给定一个 FX 图,它尝试从图中删除某些后缀。例如,给定以下内容

def f(a):
    b = x * 2
    c = b + 3
    d = c / 4
    return d

它可能会尝试截断后缀,得到

def f(a):
    b = x * 2
    c = b + 3
    return c

它以二分搜索的方式尝试这样做,尝试删除最后 1/2,然后 3/4,1/4,然后 7/8,5/8,3/8……

  1. 增量调试:当然,删除后缀并不总是足以压缩图形。如果错误是由第一条指令引起的怎么办?因此,我们采用了一种受增量调试启发的方案 - 我们尝试删除图形的中间节点。与后缀不同,删除的节点仍然存在依赖关系。因此,我们不完全删除它们,而是将它们提升为输入。例如,给定上述示例

def f(a):
    b = x * 2
    c = b + 3
    d = c / 4
    return d

我们可能会删除一个中间节点(例如,在这种情况下为 c)。

def f(a, c):
    b = x * 2
    d = c / 4
    return d

最后,还有两种辅助策略 - 消除死代码和删除未使用的输入。这些策略不言而喻。

因此,让我们看一个玩具示例。假设我们的图如果包含“乘法”则会失败。让我们创建一个失败的图。

import torch
import torch.fx as fx
from functorch.compile import minifier

def failing_f(x, y):
    y = torch.ops.aten.div(x, y)
    x = torch.ops.aten.add(x, 3)
    x = torch.ops.aten.mul(x, y)
    return torch.ops.aten.sub(x, y)

inps = [torch.randn(3), torch.randn(3)]

def pass_checker(fx_g, inps):
    return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))

min_f, inps = minifier(fx.symbolic_trace(failing_f), inps, pass_checker)
[W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::multiply.Tensor(Tensor self, Tensor other) -> (Tensor)
    registered at aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: FuncTorchBatched
  previous kernel: registered at aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:10338
       new kernel: registered at /fsx/users/chilli/work/functorch/functorch/csrc/BatchRulesDecompositions.cpp:108 (function registerKernel)
Started off with 7 nodes
###################
Current size: 7
###################
Strategy: Remove suffix

SUCCESS: Removed [4:7)

###################
Current size: 6
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4

###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 2 inputs

###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]



def forward(self, div, add):
    mul = torch.ops.aten.mul(add, div);  add = div = None
    return (mul,)
    
f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
  for _ in range(5):
    f(*inps)

瞧!我们的图现在是一个仍然失败的最小示例。

由于此 minifier 的主要用例(目前)是用于 NVFuser 重现,因此为了方便起见,我们打印出一个字符串,该字符串创建了一个自包含的重现以使用 NVFuser 运行压缩后的图。

请注意,在实践中,我们提供了 2 个主要的“图检查器” - check_nvfuser_subprocesscheck_nvfuser_correctness_subprocess。这些分别用于检查错误和正确性(即结果是否与 eager 匹配)。这些可以像这样使用

from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
minifier(failing_graph, inps, check_nvfuser_subprocess)

但是,假设你正在使用 AOTAutograd,还有一个问题 - 你如何首先获取 FX 图以传递给 minifier?一种可能的方法是简单地使用 print_compile

from functorch.compile import aot_function

from functorch.compile import print_compile
# Or...
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g

def foo(x):
    return x.cos().cos()
inp = torch.randn(3, requires_grad=True)
aot_function(foo, print_compile)(inp)
def forward(self, primals_1):
    cos = torch.ops.aten.cos(primals_1)
    cos_1 = torch.ops.aten.cos(cos)
    return [cos_1, primals_1, cos]
    



def forward(self, primals_1, cos, tangents_1):
    sin = torch.ops.aten.sin(cos);  cos = None
    neg = torch.ops.aten.neg(sin);  sin = None
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin(primals_1);  primals_1 = None
    neg_1 = torch.ops.aten.neg(sin_1);  sin_1 = None
    mul_1 = torch.ops.aten.mul(mul, neg_1);  mul = neg_1 = None
    return [mul_1]
    
tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)

但是,这不会提供输入,也不会处理可能保存在图中的任何张量常量。为了解决这个问题,我们还有一个名为 debug_compile 的“编译器”。它只是打印出一个可以复制粘贴并从另一个文件中运行的字符串。它利用 FX 的 to_folder 功能将图序列化到磁盘,以及任何常量。

你可以将其应用于 fw_compiler 以转储前向图或 bw_compiler 以转储后向图。

from functorch.compile import memory_efficient_fusion, debug_compile

memory_efficient_fusion(foo, bw_compiler=debug_compile)(inp)
##############################################################
# To minimize FX graph, copy and paste the below and run it  #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
  # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
  minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
tensor([0.6062, 0.9982, 0.6474], grad_fn=<CompiledFunctionBackward>)

因此,让我们复制粘贴它并看看它是如何工作的 - 请注意,我做了一些小的修改以在 CPU 上运行并使用之前的“如果图中存在乘法则失败”检查器。

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.ones(shape, dtype=dtype) for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule()

minifier(fx.symbolic_trace(mod), inps, pass_checker)
Started off with 10 nodes
###################
Current size: 10
###################
Strategy: Remove suffix

SUCCESS: Removed [6:10)

###################
Current size: 8
###################
Strategy: Delta Debugging
SUCCESS: Removed (0:4] - Went from 2 placeholders to 4

###################
Current size: 8
###################
Strategy: Remove unused inputs
SUCCESS: Went from 4 inputs to 3 inputs

###################
Current size: 7
###################
Strategy: Remove suffix

SUCCESS: Removed [4:7)

###################
Current size: 6
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs

###################
Current size: 5
###################
Strategy: Delta Debugging
SUCCESS: Removed (2:3] - Went from 2 placeholders to 3

###################
Current size: 5
###################
Strategy: Remove unused inputs
SUCCESS: Went from 3 inputs to 2 inputs

###################
Current size: 4
###################
Strategy: Remove suffix
FAIL: Could not remove suffix
Strategy: Delta Debugging
FAIL: Could not remove prefix

inps = [(torch.Size([3]), torch.float32), (torch.Size([3]), torch.float32)]
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]



def forward(self, tangents_1, neg):
    mul = torch.ops.aten.mul(tangents_1, neg);  tangents_1 = neg = None
    return (mul,)
    
f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
  for _ in range(5):
    f(*inps)
(GraphModule(), [tensor([1., 1., 1.]), tensor([-0.5144, -0.5144, -0.5144])])

希望这对你有所帮助 :)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源