torch.Tensor.register_hook¶
- Tensor.register_hook(hook)[源][源]¶
註冊一個反向鉤子。
每次計算相對於該 Tensor 的梯度時,都會呼叫此鉤子。此鉤子應具有以下簽名:
hook(grad) -> Tensor or None
鉤子不應修改其引數,但可以選擇返回一個新的梯度,該梯度將用於替換
grad。此函式返回一個控制代碼,該控制代碼包含一個方法
handle.remove(),用於從模組中移除該鉤子。注意
關於此鉤子的執行時間以及其相對於其他鉤子的執行順序的更多資訊,請參閱反向鉤子執行。
示例
>>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.tensor([1., 2., 3.])) >>> v.grad 2 4 6 [torch.FloatTensor of size (3,)] >>> h.remove() # removes the hook