torch.nn.utils.stateless.functional_call¶
- torch.nn.utils.stateless.functional_call(module, parameters_and_buffers, args=None, kwargs=None, *, tie_weights=True, strict=False)[源代码][源代码]¶
通过用提供的参数和缓冲区替换模块的参数和缓冲区来执行模块上的函数式调用。
警告
此 API 已在 PyTorch 2.0 中弃用,并将在未来版本中移除。请改用
torch.func.functional_call()
,它是此 API 的直接替代品。注意
如果模块具有活跃的参数化 (parametrization),在
parameters_and_buffers
参数中传入一个名称设置为常规参数名称的值将完全禁用该参数化。如果要将参数化函数应用于传入的值,请将键设置为{submodule_name}.parametrizations.{parameter_name}.original
。注意
如果模块对参数/缓冲区执行原地 (in-place) 操作,这些操作将反映在 parameters_and_buffers 输入中。
示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # does self.foo = self.foo + 1 >>> print(mod.foo) # tensor(0.) >>> functional_call(mod, a, torch.ones(())) >>> print(mod.foo) # tensor(0.) >>> print(a['foo']) # tensor(1.)
注意
如果模块具有绑定权重 (tied weights),则 functional_call 是否遵循绑定取决于 tie_weights 标志。
示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied >>> print(mod.foo) # tensor(1.) >>> mod(torch.zeros(())) # tensor(2.) >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
- 参数
module (torch.nn.Module) – 要调用的模块
parameters_and_buffers (dict of str and Tensor) – 将在模块调用中使用的参数。
args (Any or tuple) – 要传递给模块调用的参数。如果不是元组,则视为单个参数。
kwargs (dict) – 要传递给模块调用的关键字参数
tie_weights (bool, optional) – 如果为 True,则原始模型中绑定的参数和缓冲区在重新参数化版本中也将被视为绑定。因此,如果为 True 且为绑定的参数和缓冲区传入了不同的值,将引发错误。如果为 False,则不会遵循原始绑定的参数和缓冲区,除非为两个权重传入的值相同。默认值:True。
strict (bool, optional) – 如果为 True,则传入的参数和缓冲区必须与原始模块中的参数和缓冲区匹配。因此,如果为 True 且存在任何缺失或意外的键,将引发错误。默认值:False。
- 返回值
调用
module
的结果。- 返回值类型
Any