自定义后端¶
概述¶
torch.compile
提供了一种简单的方法,使用户能够定义自定义后端。
后端函数具有以下约定 (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable
。
在跟踪 FX 图之后,TorchDynamo(torch.compile
的图跟踪组件)可以调用后端函数,并期望返回一个与跟踪的 FX 图等效的编译函数。返回的可调用对象应该与传递到后端的原始 torch.fx.GraphModule
的 forward
函数具有相同的约定:(*args: torch.Tensor) -> List[torch.Tensor]
。
为了让 TorchDynamo 调用您的后端,请将您的后端函数作为 backend
关键字参数传递给 torch.compile
。例如,
import torch
def my_custom_backend(gm, example_inputs):
return gm.forward
def f(...):
...
f_opt = torch.compile(f, backend=my_custom_backend)
@torch.compile(backend=my_custom_backend)
def g(...):
...
有关更多示例,请参见下文。
注册自定义后端¶
您可以使用 register_backend
装饰器注册您的后端,例如,
from torch._dynamo import register_backend
@register_backend
def my_compiler(gm, example_inputs):
...
除了 register_backend
装饰器之外,如果您的后端在另一个 Python 包中,您还可以通过 Python 包的入口点注册您的后端,这提供了一种方法,使一个包可以为另一个包注册一个插件。
提示
您可以了解有关 entry_points
的更多信息,请参阅 Python 打包文档。
要通过 entry_points
注册您的后端,您可以在包的 setup.py
文件中将您的后端函数添加到 torch_dynamo_backends
入口点组,如下所示
...
setup(
...
'torch_dynamo_backends': [
'my_compiler = your_module.submodule:my_compiler',
]
...
)
请将 my_compiler
(在 =
之前)替换为您的后端名称,并将 =
之后的部分替换为您的后端函数的模块和函数名。安装包后,入口点将添加到您的 Python 环境中。当您调用 torch.compile(model, backend="my_compiler")
时,PyTorch 首先会搜索使用 register_backend
注册的名称为 my_compiler
的后端。如果未找到,它将继续搜索通过 entry_points
注册的所有后端。
注册具有两个目的
您可以将包含后端函数名称的字符串传递给
torch.compile
,而不是函数本身,例如torch.compile(model, backend="my_compiler")
。它对于与 缩小器 一起使用是必需的。缩小器生成的任何代码都必须调用注册后端函数的代码,通常通过
import
语句。
AOTAutograd 后的自定义后端¶
可以定义由 AOTAutograd 而不是 TorchDynamo 调用的自定义后端。这对于以下两个主要原因很有用
用户可以定义支持模型训练的后端,因为 AOTAutograd 可以生成用于编译的反向图。
AOTAutograd 生成由 核心 Aten 运算符 组成的 FX 图。因此,自定义后端只需要支持核心 Aten 运算符集,而核心 Aten 运算符集远小于整个 torch/Aten 运算符集。
使用 torch._dynamo.backends.common.aot_autograd
包装您的后端,并使用 torch.compile
以及 backend
关键字参数(如前所述)。由 aot_autograd
包装的后端函数应该与以前具有相同的约定。
后端函数通过 fw_compiler
(前向编译器)或 bw_compiler
(反向编译器)关键字参数传递给 aot_autograd
。如果未指定 bw_compiler
,则反向编译函数将默认为前向编译函数。
需要注意的是,AOTAutograd 要求后端返回的编译函数是“已封装”的。这可以通过使用 functorch.compile.make_boxed_func
包装编译函数来完成。
例如,
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func
def my_compiler(gm, example_inputs):
return make_boxed_func(gm.forward)
my_backend = aot_autograd(fw_compiler=my_compiler) # bw_compiler=my_compiler
model_opt = torch.compile(model, backend=my_backend)
示例¶
调试后端¶
如果您想更好地了解编译期间发生了什么,您可以创建一个自定义编译器(在本节中称为后端),它将打印提取自 Dynamo 字节码分析的 fx GraphModule
的格式化输出,并返回一个 forward()
可调用对象。
例如
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torch.compile(backend=my_compiler)
def fn(x, y):
a = torch.cos(x)
b = torch.sin(y)
return a + b
fn(torch.randn(10), torch.randn(10))
运行上述示例将产生以下输出
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ---------- --------
placeholder x x () {}
placeholder y y () {}
call_function cos <built-in method cos of type object at 0x7f1a894649a8> (x,) {}
call_function sin <built-in method sin of type object at 0x7f1a894649a8> (y,) {}
call_function add <built-in function add> (cos, sin) {}
output output output ((add,),) {}
这适用于 torch.nn.Module
,如下所示
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(torch.cos(x))
mod = MockModule()
optimized_mod = torch.compile(mod, backend=my_compiler)
optimized_mod(torch.randn(10))
让我们再看一个带有控制流的例子
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torch.compile(backend=my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
运行此示例将产生以下输出
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f8d259298a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}
The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.
快速后端¶
集成提供更高性能的自定义后端也很容易,我们将使用 optimize_for_inference 集成一个真正的后端
def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
scripted = torch.jit.script(gm)
return torch.jit.optimize_for_inference(scripted)
然后,您应该能够使用以下方法优化任何现有代码:
@torch.compile(backend=optimize_for_inference_compiler)
def code_to_accelerate():
...
可组合后端¶
TorchDynamo 包含许多后端,可以使用 torch._dynamo.list_backends()
列出它们。您可以使用以下代码将这些后端组合在一起
from torch._dynamo import lookup_backend
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
try:
trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
if trt_compiled is not None:
return trt_compiled
except Exception:
pass
# first backend failed, try something else...
try:
inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
if inductor_compiled is not None:
return inductor_compiled
except Exception:
pass
return gm.forward