torch.func.jvp¶
- torch.func.jvp(func, primals, tangents, *, strict=False, has_aux=False)[source]¶
代表着雅可比向量积,返回一个包含 func(*primals) 的输出以及“在
primals
处计算的func
的雅可比”乘以tangents
的元组。这也称为前向自动微分。- 参数
func (function) – 一个 Python 函数,接受一个或多个参数(其中一个必须是 Tensor),并返回一个或多个 Tensor
primals (Tensors) –
func
的位置参数,必须都是 Tensor。返回的函数也将计算相对于这些参数的导数。tangents (Tensors) – 计算雅可比向量积的“向量”。必须与
func
的输入具有相同的结构和尺寸。has_aux (bool) – 标志,指示
func
返回一个(output, aux)
元组,其中第一个元素是要微分的函数输出,第二个元素是其他不会被微分的辅助对象。默认值:False。
- 返回值
返回一个包含在
primals
处计算的func
的输出以及雅可比向量积的(output, jvp_out)
元组。如果has_aux 为 True
,则改为返回一个(output, jvp_out, aux)
元组。
注意
您可能会看到此 API 报错“operator X 未实现前向自动微分”。如果发生这种情况,请提交错误报告,我们将优先处理。
当您希望计算函数 R^1 -> R^N 的梯度时,jvp 非常有用
>>> from torch.func import jvp >>> x = torch.randn([]) >>> f = lambda x: x * torch.tensor([1., 2., 3]) >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) >>> assert torch.allclose(value, f(x)) >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
jvp()
可以通过为每个输入传递 tangents 来支持具有多个输入的函数>>> from torch.func import jvp >>> x = torch.randn(5) >>> y = torch.randn(5) >>> f = lambda x, y: (x * y) >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) >>> assert torch.allclose(output, x + y)