快捷方式

torch.Tensor.register_post_accumulate_grad_hook

Tensor.register_post_accumulate_grad_hook(hook)[原始碼][原始碼]

註冊一個在梯度累積後執行的反向傳播 hook。

當張量上所有梯度累積完成後,即該張量的 .grad 欄位已更新時,就會呼叫此 hook。梯度累積後 hook 僅適用於葉張量(即沒有 .grad_fn 欄位的張量)。在非葉張量上註冊此 hook 將會報錯!

此 hook 應具有以下簽名:

hook(param: Tensor) -> None

請注意,與其他 autograd hook 不同,此 hook 是作用於需要梯度的張量本身,而不是梯度本身。該 hook 可以原地修改和訪問其張量引數,包括其 .grad 欄位。

此函式返回一個 handle,該 handle 有一個 handle.remove() 方法,用於從模組中移除 hook。

注意

有關此 hook 何時執行以及與其他 hook 的執行順序的更多資訊,請參閱 反向傳播 Hook 執行。由於此 hook 在反向傳播過程中執行,它將在 no_grad 模式下執行(除非 create_graph 為 True)。如果需要,您可以使用 torch.enable_grad() 在 hook 內重新啟用 autograd。

示例

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> lr = 0.01
>>> # simulate a simple SGD update
>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v
tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)

>>> h.remove()  # removes the hook

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源