快捷方式

torch.compiler.substitute_in_graph

torch.compiler.substitute_in_graph(original_fn, *, can_constant_fold_through=False, skip_signature_check=False)[源代码]

为函数(通常是来自 C 扩展的 C 函数)注册一个 polyfill 处理程序,以便在图中内联原始函数时使用该处理程序代替原始函数。

注意

仅在内联原始函数时才使用 polyfill 处理程序。当直接调用原始函数时,不会使用它。在急切模式下,装饰后的函数调用高性能 C 函数而不是 polyfill 处理程序。

polyfill 处理程序是一个函数,在内联原始函数时,它将被调用以代替原始函数。polyfill 处理程序应具有与原始函数相同的签名和行为。

参数
  • original_fn (callable) – 要为其注册 polyfill 处理程序的原始函数,通常是 C 函数。

  • can_constant_fold_through (bool, 可选) – polyfill 处理程序是否可以进行常量折叠。也就是说,如果 polyfill 处理程序是纯函数,并且其参数是常量,则 polyfill 处理程序的结果可以在编译期间进行常量折叠。默认为 False

  • skip_signature_check (bool, 可选) – 是否跳过原始函数和 polyfill 处理程序之间的签名检查。默认为 False

返回值

一个装饰器,为原始函数注册 polyfill 处理程序。

返回类型

Callable[[_F], _F]

示例

>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
... # xdoctest: +SKIP("Long tracebacks")
Traceback (most recent call last):
...
torch._dynamo.exc.Unsupported: ...

>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(a, b, /):
...     for i, item in enumerate(a):
...         if item is b or item == b:
...             return i
...     raise ValueError("sequence.index(x): x not in sequence")
>>>
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源