快捷方式

torch.func.linearize

torch.func.linearize(func, *primals)[source]

返回 primals 处的 func 值以及在 primals 处的线性近似。

参数
  • func (Callable) – 一个接受一个或多个参数的 Python 函数。

  • primals (Tensors) – func 的位置参数,必须全部是 Tensor。这些是函数进行线性近似时的值。

返回值

返回一个 (output, jvp_fn) 元组,包含应用于 primalsfunc 输出,以及一个计算在 primals 处求值的 func 的 jvp 的函数。

返回类型

tuple[Any, Callable]

如果在 primals 处多次计算 jvp,则 linearize 会很有用。但是,为此,linearize 会保存中间计算结果,并且比直接应用 jvp 需要更高的内存。因此,如果所有 tangents 都已知,计算 vmap(jvp) 可能比使用 linearize 更高效。

注意

linearize 会两次评估 func。请提交一个 issue 以实现单次评估版本。

示例:
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源