functorch.compile.aot_function¶
-
functorch.compile.
aot_function
(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, num_params_buffers=0, hasher_type=None, static_argnums=None, keep_inference_input_mutations=False)[source]¶ 使用 torch 派发机制追踪
fn
的前向和反向图,然后通过fw_compiler
和bw_compiler
编译生成的前向和反向图。aot_function()
提前追踪前向和反向图,并生成一个联合的前向和反向图。然后使用partition_fn
将前向和反向图分离。分区函数可用于执行诸如重新计算之类的优化。可以设置 decompositions 字典,将运算符分解为后端编译器支持的一系列核心或更简单的运算符。aot_function()
使用基于输入张量属性的编译缓存,检测是否需要重新编译。警告
此 API 处于实验阶段,可能会更改。
- 参数
fn (Callable) – 一个接受一个或多个参数的 Python 函数。必须返回一个或多个张量。
fw_compiler (Callable) – 一个 Python 函数,它接受一个具有 Aten 运算符和输入参数的 Fx 图,并返回一个在语义上等效于输入 Fx 图的可调用对象。
bw_compiler (Optional[Callable]) – 一个 Python 函数,它接受一个具有 Aten 运算符和输入参数的 Fx 图,并返回一个在语义上等效于输入 Fx 图的可调用对象。默认值:None(如果为 None,则默认为
fw_compiler
)partition_fn (Callable) – 一个 Python 函数,它接受一个联合的前向和反向图,并将它划分成单独的前向和反向图。
decompositions (Dict) – 用于定义将更大的 Aten 运算符分解成更简单或核心 Aten 运算符的字典。
- 返回值
返回一个
Callable
,它保留原始fn
的急切行为,但前向和反向图通过fw_compile
和bw_compile
编译。
aot_function()
的一个简单示例用法如下。此示例将打印函数fn
的前向和反向图>>> fn = lambda x : x.sin().cos() >>> def print_compile_fn(fx_module, args): >>> print(fx_module) >>> return fx_module >>> aot_fn = aot_function(fn, print_compile_fn) >>> x = torch.randn(4, 5, requires_grad=True) >>> aot_fn(x)