torch.autograd.Function.backward¶
- static Function.backward(ctx, *grad_outputs)[原始碼]¶
定義使用反向模式自動微分計算操作微分的公式。
此函式必須由所有子類覆蓋。(定義此函式等同於定義
vjp函式。)它必須接受一個上下文
ctx作為第一個引數,其後是與forward()返回的輸出數量相等的引數(對於 forward 函式的非張量輸出,將傳遞 None),並且它應該返回與forward()輸入數量相等的張量。每個引數是相對於給定輸出的梯度,並且每個返回值應該是相對於相應輸入的梯度。如果輸入不是張量或是不需要梯度的張量,您可以直接傳遞 None 作為該輸入的梯度。上下文可用於檢索在 forward 傳遞期間儲存的張量。它還有一個屬性
ctx.needs_input_grad,它是一個布林值元組,表示每個輸入是否需要梯度。例如,如果forward()的第一個輸入相對於輸出需要計算梯度,則backward()將具有ctx.needs_input_grad[0] = True。- 返回型別