快捷方式

functorch.functionalize

functorch.functionalize(func, *, remove='mutations')[源代码]

functionalize 是一种变换,可用于从函数中删除(中间)变异和别名,同时保留函数的语义。

functionalize(func) 返回一个与 func 语义相同的新函数,但所有中间变异都被删除。对中间张量执行的每个就地操作:intermediate.foo_() 将被其非就地等效项替换:intermediate_updated = intermediate.foo()

functionalize 对于将 PyTorch 程序发送到无法轻松表示变异或别名运算符的后端或编译器很有用。

参数
  • func (Callable) – 一个接受一个或多个参数的 Python 函数。

  • remove (str) – 一个可选的字符串参数,取值为 ‘mutations’ 或 ‘mutations_and_views’。如果传递 ‘mutations’,则所有变异运算符将被其非变异等效项替换。如果传递 ‘mutations_and_views’,则此外,所有别名运算符将被其非别名等效项替换。默认值:‘mutations’。

返回值

返回一个新的“功能化”函数。它接受与 func 相同的输入,并具有相同的行为,但函数中对中间张量执行的任何变异(以及可选的别名)都将被删除。

functionalize 还将删除对函数输入执行的变异(和视图)。但是,为了保留语义,functionalize 将在变换完成运行后“修复”变异,通过检测是否应该对任何张量输入进行变异,并在必要时将新数据复制回输入。

示例

>>> # xdoctest: +SKIP
>>> import torch
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.func import functionalize
>>>
>>> # A function that uses mutations and views, but only on intermediate tensors.
>>> def f(a):
...     b = a + 1
...     c = b.view(-1)
...     c.add_(1)
...     return b
...
>>> inpt = torch.randn(2)
>>>
>>> out1 = f(inpt)
>>> out2 = functionalize(f)(inpt)
>>>
>>> # semantics are the same (outputs are equivalent)
>>> print(torch.allclose(out1, out2))
True
>>>
>>> f_traced = make_fx(f)(inpt)
>>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
>>> print(f_traced.code)



def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1])
    add_ = torch.ops.aten.add_(view, 1);  view = None
    return add

>>> print(f_no_mutations_traced.code)



def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view, 1);  view = None
    view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
    return view_1

>>> print(f_no_mutations_and_views_traced.code)



def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
    return view_copy_1


>>> # A function that mutates its input tensor
>>> def f(a):
...     b = a.view(-1)
...     b.add_(1)
...     return a
...
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>> #
>>> # All mutations and views have been removed,
>>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
>>> # after the function has completed.
>>> print(f_no_mutations_and_views_traced.code)



def forward(self, a_1):
    view_copy = torch.ops.aten.view_copy(a_1, [-1])
    add = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
    copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
    return view_copy_1
有一些 functionalize 的“故障模式”值得一提
  1. 与其他 torch.func 变换一样,functionalize() 不适用于直接使用 .backward() 的函数。torch.autograd.grad 也是如此。如果你想使用 autograd,你可以直接使用 functionalize(grad(f)) 计算梯度。

  2. 与其他 torch.func 变换一样,functionalize() 不适用于全局状态。如果你对接受非本地状态的视图/变异的函数调用 functionalize(f),功能化将简单地执行无操作并将视图/变异调用直接传递给后端。解决此问题的一种方法是确保将任何非本地状态创建包装到一个更大的函数中,然后对该函数调用 functionalize。

  3. resize_() 有一些限制:只有当被调整大小的张量不是视图时,functionalize 才能在使用 resize_() 的程序上运行。

  4. as_strided() 有一些限制:functionalize 不会在导致具有重叠内存的张量的 as_strided() 调用上运行。

最后,理解功能化的一个有用的心理模型是,大多数用户 PyTorch 程序都是用公共 torch API 编写的。当执行时,torch 运算符通常分解为我们内部的 C++“ATen”API。功能化的逻辑完全发生在 ATen 的级别。功能化知道如何获取 ATen 中的每个别名运算符,并将其映射到其非别名等效项(例如 tensor.view({-1}) -> at::view_copy(tensor, {-1})),以及如何获取 ATen 中的每个变异运算符,并将其映射到其非变异等效项(例如 tensor.add_(1) -> at::add(tensor, -1)),同时跟踪别名和变异,以便知道何时进行修复。有关哪些 ATen 运算符是别名或变异的信息来自 https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.

警告

我们将 functorch 集成到 PyTorch 中。作为集成的最后一步,functorch.functionalize 从 PyTorch 2.0 开始已弃用,将在 PyTorch >= 2.3 的未来版本中删除。请改用 torch.func.functionalize;有关更多详细信息,请参阅 PyTorch 2.0 发行说明和/或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适用于初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源