torch.Tensor.register_post_accumulate_grad_hook¶
- Tensor.register_post_accumulate_grad_hook(hook)[源代码][源代码]¶
注册一个在梯度累积后运行的反向钩子。
该钩子将在张量的所有梯度累积完成后被调用,这意味着该张量的 .grad 字段已被更新。后累积梯度钩子仅适用于叶张量(没有 .grad_fn 字段的张量)。在非叶张量上注册此钩子会报错!
钩子应具有以下签名
hook(param: Tensor) -> None
请注意,与其他自动求导钩子不同,此钩子作用于需要梯度的张量,而不是梯度本身。钩子可以就地修改和访问其 Tensor 参数,包括其 .grad 字段。
此函数返回一个带有
handle.remove()
方法的句柄,该方法从模块中移除钩子。注意
有关此钩子的执行时间以及其相对于其他钩子的执行顺序的更多信息,请参阅 反向钩子执行。由于此钩子在反向传播期间运行,它将在 no_grad 模式下运行(除非 create_graph 为 True)。如果需要,您可以使用 torch.enable_grad() 在钩子内重新启用自动求导。
示例
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> lr = 0.01 >>> # simulate a simple SGD update >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) >>> v.backward(torch.tensor([1., 2., 3.])) >>> v tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) >>> h.remove() # removes the hook