快捷方式

functorch.compile.aot_function

functorch.compile.aot_function(fn, fw_compiler, bw_compiler=None, partition_fn=<function default_partition>, decompositions=None, num_params_buffers=0, hasher_type=None, static_argnums=None, keep_inference_input_mutations=False)[source]

使用 torch 派发机制追踪 fn 的前向和反向图,然后通过 fw_compilerbw_compiler 编译生成的前向和反向图。

aot_function() 提前追踪前向和反向图,并生成一个联合的前向和反向图。然后使用 partition_fn 将前向和反向图分离。分区函数可用于执行诸如重新计算之类的优化。可以设置 decompositions 字典,将运算符分解为后端编译器支持的一系列核心或更简单的运算符。

aot_function() 使用基于输入张量属性的编译缓存,检测是否需要重新编译。

警告

此 API 处于实验阶段,可能会更改。

参数
  • fn (Callable) – 一个接受一个或多个参数的 Python 函数。必须返回一个或多个张量。

  • fw_compiler (Callable) – 一个 Python 函数,它接受一个具有 Aten 运算符和输入参数的 Fx 图,并返回一个在语义上等效于输入 Fx 图的可调用对象。

  • bw_compiler (Optional[Callable]) – 一个 Python 函数,它接受一个具有 Aten 运算符和输入参数的 Fx 图,并返回一个在语义上等效于输入 Fx 图的可调用对象。默认值:None(如果为 None,则默认为 fw_compiler

  • partition_fn (Callable) – 一个 Python 函数,它接受一个联合的前向和反向图,并将它划分成单独的前向和反向图。

  • decompositions (Dict) – 用于定义将更大的 Aten 运算符分解成更简单或核心 Aten 运算符的字典。

返回值

返回一个 Callable,它保留原始 fn 的急切行为,但前向和反向图通过 fw_compilebw_compile 编译。

aot_function() 的一个简单示例用法如下。此示例将打印函数 fn 的前向和反向图

>>> fn = lambda x : x.sin().cos()
>>> def print_compile_fn(fx_module, args):
>>>     print(fx_module)
>>>     return fx_module
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获得针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源