torch.Tensor.register_post_accumulate_grad_hook¶
- Tensor.register_post_accumulate_grad_hook(hook)[原始碼][原始碼]¶
註冊一個在梯度累積後執行的反向傳播 hook。
當張量上所有梯度累積完成後,即該張量的 .grad 欄位已更新時,就會呼叫此 hook。梯度累積後 hook 僅適用於葉張量(即沒有 .grad_fn 欄位的張量)。在非葉張量上註冊此 hook 將會報錯!
此 hook 應具有以下簽名:
hook(param: Tensor) -> None
請注意,與其他 autograd hook 不同,此 hook 是作用於需要梯度的張量本身,而不是梯度本身。該 hook 可以原地修改和訪問其張量引數,包括其 .grad 欄位。
此函式返回一個 handle,該 handle 有一個
handle.remove()方法,用於從模組中移除 hook。注意
有關此 hook 何時執行以及與其他 hook 的執行順序的更多資訊,請參閱 反向傳播 Hook 執行。由於此 hook 在反向傳播過程中執行,它將在
no_grad模式下執行(除非create_graph為 True)。如果需要,您可以使用torch.enable_grad()在 hook 內重新啟用 autograd。示例
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> lr = 0.01 >>> # simulate a simple SGD update >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) >>> v.backward(torch.tensor([1., 2., 3.])) >>> v tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) >>> h.remove() # removes the hook