快捷方式

functorch.jvp

functorch.jvp(func, primals, tangents, *, strict=False, has_aux=False)[源代码]

代表雅可比向量积,返回一个元组,包含 func(*primals) 的输出以及“在 primals 处计算的 func 的雅可比矩阵”乘以 tangents 的结果。这也称为正向模式自动微分。

参数
  • func (函数) – 一个 Python 函数,它接受一个或多个参数,其中一个必须是张量,并返回一个或多个张量。

  • primals (张量) – func 的位置参数,它们必须都是张量。返回的函数也将计算对这些参数的导数。

  • tangents (张量) – 用于计算雅可比向量积的“向量”。它必须与 func 的输入具有相同的结构和大小。

  • has_aux (布尔值) – 标志指示 func 是否返回一个 (output, aux) 元组,其中第一个元素是要微分的函数的输出,第二个元素是其他不会被微分的辅助对象。默认值为 False。

返回值

返回一个 (output, jvp_out) 元组,包含在 primals 处计算的 func 的输出和雅可比向量积。如果 has_aux True,则返回一个 (output, jvp_out, aux) 元组。

注意

你可能会看到这个 API 错误消息 “forward-mode AD not implemented for operator X”。如果是这样,请提交一个 bug 报告,我们会优先处理它。

jvp 适用于你希望计算 R^1 -> R^N 函数的梯度时。

>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))

jvp() 通过传入每个输入的切线可以支持具有多个输入的函数。

>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)

警告

我们已经将 functorch 集成到 PyTorch 中。作为集成的最后一步,functorch.jvp 自 PyTorch 2.0 开始已弃用,并将从 PyTorch >= 2.3 的未来版本中删除。请改用 torch.func.jvp;有关更多详细信息,请参阅 PyTorch 2.0 发行说明或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源