torch.Tensor.register_post_accumulate_grad_hook¶
- Tensor.register_post_accumulate_grad_hook(hook)[源代码][源代码]¶
注册一个在梯度累积后运行的反向传播 hook。
当张量上所有梯度累积完成后,即该张量的 .grad 字段已更新时,就会调用此 hook。梯度累积后 hook 仅适用于叶张量(即没有 .grad_fn 字段的张量)。在非叶张量上注册此 hook 将会报错!
此 hook 应具有以下签名:
hook(param: Tensor) -> None
请注意,与其他 autograd hook 不同,此 hook 是作用于需要梯度的张量本身,而不是梯度本身。该 hook 可以原地修改和访问其张量参数,包括其 .grad 字段。
此函数返回一个 handle,该 handle 有一个
handle.remove()
方法,用于从模块中移除 hook。注意
有关此 hook 何时执行以及与其他 hook 的执行顺序的更多信息,请参阅 反向传播 Hook 执行。由于此 hook 在反向传播过程中运行,它将在
no_grad
模式下运行(除非create_graph
为 True)。如果需要,您可以使用torch.enable_grad()
在 hook 内重新启用 autograd。示例
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> lr = 0.01 >>> # simulate a simple SGD update >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) >>> v.backward(torch.tensor([1., 2., 3.])) >>> v tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) >>> h.remove() # removes the hook