torch.func.linearize¶
- torch.func.linearize(func, *primals)¶
返回在
primals
处的func
值和在primals
处的线性近似。- 参数
func (Callable) – 一个接受一个或多个参数的 Python 函数。
primals (Tensors) –
func
的位置参数,这些参数必须全部是张量。这些是在其中线性近似函数的值。
- 返回值
返回一个
(output, jvp_fn)
元组,其中包含应用于primals
的func
的输出,以及一个计算在primals
处评估的func
的 jvp 的函数。- 返回类型
如果要在
primals
处多次计算 jvp,则 linearize 很有用。但是,为了实现这一点,linearize 保存了中间计算,并且比直接应用 jvp 具有更高的内存需求。因此,如果所有tangents
都已知,使用 vmap(jvp) 计算可能比使用 linearize 更有效。注意
linearize 对
func
进行了两次评估。请为使用单次评估的实现提交问题。- 示例:
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>