快捷方式

functorch.grad

functorch.grad(func, argnums=0, has_aux=False)[source]

grad 运算符有助于计算 func 相对于由 argnums 指定的输入的梯度。此运算符可以嵌套以计算高阶梯度。

参数
  • func (Callable) – 一个 Python 函数,它接受一个或多个参数。必须返回一个单元素张量。如果指定 has_aux 等于 True,则函数可以返回一个单元素张量和其他辅助对象的元组:(output, aux)

  • argnums (intTuple[int]) – 指定要计算梯度的参数。argnums 可以是单个整数或整数元组。默认值:0。

  • has_aux (bool) – 标志,指示 func 返回一个张量和其他辅助对象:(output, aux)。默认值:False。

返回值

用于计算其输入梯度的函数。默认情况下,函数的输出是相对于第一个参数的梯度张量。如果指定 has_aux 等于 True,则返回梯度和输出辅助对象的元组。如果 argnums 是一个整数元组,则返回一个相对于每个 argnums 值的输出梯度元组。

使用 grad 的示例

>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())

当与 vmap 组合时,grad 可用于计算每个样本的梯度

>>> # xdoctest: +SKIP
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

使用 grad 以及 has_auxargnums 的示例

>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>    loss_per_sample = (0.5 * y_pred - y) ** 2
>>>    loss = loss_per_sample.mean()
>>>    return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))

注意

将 PyTorch torch.no_gradgrad 结合使用。

情况 1:在函数内部使用 torch.no_grad

>>> # xdoctest: +SKIP
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

在这种情况下,grad(f)(x) 将遵循内部的 torch.no_grad

情况 2:在 torch.no_grad 上下文管理器内部使用 grad

>>> # xdoctest: +SKIP
>>> with torch.no_grad():
>>>     grad(f)(x)

在这种情况下,grad 将遵循内部的 torch.no_grad,但不遵循外部的。这是因为 grad 是一个“函数转换”:其结果不应依赖于 f 外部的上下文管理器的结果。

警告

我们已将 functorch 集成到 PyTorch 中。作为集成的最后一步,从 PyTorch 2.0 开始,functorch.grad 已弃用,并在 PyTorch >= 2.3 的未来版本中删除。请改用 torch.func.grad;有关更多详细信息,请参阅 PyTorch 2.0 发行说明和/或 torch.func 迁移指南 https://pytorch.ac.cn/docs/master/func.migrating.html

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源