torch.func.functional_call¶
- torch.func.functional_call(module, parameter_and_buffer_dicts, args=None, kwargs=None, *, tie_weights=True, strict=False)[源代码]¶
通过使用提供的参数和缓冲区替换模块参数和缓冲区,对模块执行函数式调用。
注意
如果模块具有活动参数化,则在
parameter_and_buffer_dicts
参数中传递一个值,并将名称设置为常规参数名称将完全禁用参数化。如果您想将参数化函数应用于传递的值,请将键设置为{submodule_name}.parametrizations.{parameter_name}.original
。注意
如果模块对参数/缓冲区执行就地操作,这些操作将反映在
parameter_and_buffer_dicts
输入中。示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # does self.foo = self.foo + 1 >>> print(mod.foo) # tensor(0.) >>> functional_call(mod, a, torch.ones(())) >>> print(mod.foo) # tensor(0.) >>> print(a['foo']) # tensor(1.)
注意
如果模块具有绑定的权重,functional_call 是否尊重绑定取决于 tie_weights 标志。
示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied >>> print(mod.foo) # tensor(1.) >>> mod(torch.zeros(())) # tensor(2.) >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
传递多个字典的示例
a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer print(mod.weight) # tensor(...) print(mod.buffer) # tensor(...) x = torch.randn((1, 1)) print(x) functional_call(mod, a, x) # same as x print(mod.weight) # same as before functional_call
这是一个对模型的参数应用 grad 变换的示例。
import torch import torch.nn as nn from torch.func import functional_call, grad x = torch.randn(4, 3) t = torch.randn(4, 3) model = nn.Linear(3, 3) def compute_loss(params, x, t): y = functional_call(model, params, x) return nn.functional.mse_loss(y, t) grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
注意
如果用户不需要在 grad 变换之外进行梯度跟踪,他们可以分离所有参数以获得更好的性能和内存使用率
示例
>>> detached_params = {k: v.detach() for k, v in model.named_parameters()} >>> grad_weights = grad(compute_loss)(detached_params, x, t) >>> grad_weights.grad_fn # None--it's not tracking gradients outside of grad
这意味着用户无法调用
grad_weight.backward()
。但是,如果他们不需要在变换之外进行 autograd 跟踪,这将减少内存使用并提高速度。- 参数
module (torch.nn.Module) – 要调用的模块
parameters_and_buffer_dicts (Dict[str, Tensor] 或 tuple of Dict[str, Tensor]) – 将在模块调用中使用的参数。如果给定字典的元组,它们必须具有不同的键,以便所有字典可以一起使用
args (Any 或 tuple) – 要传递给模块调用的参数。如果不是元组,则视为单个参数。
kwargs (dict) – 要传递给模块调用的关键字参数
tie_weights (bool, 可选) – 如果为 True,则原始模型中绑定的参数和缓冲区将在重新参数化的版本中被视为绑定。因此,如果为 True 并且为绑定的参数和缓冲区传递了不同的值,则会报错。如果为 False,则除非为两个权重传递的值相同,否则它将不尊重最初绑定的参数和缓冲区。默认值:True。
strict (bool, 可选) – 如果为 True,则传入的参数和缓冲区必须与原始模块中的参数和缓冲区匹配。因此,如果为 True 并且有任何缺失或意外的键,则会报错。默认值:False。
- 返回
调用
module
的结果。- 返回类型
Any