torch.autograd.Function.forward¶
- static Function.forward(*args, **kwargs)[源]¶
定義自定義 autograd Function 的前向傳播。
所有子類都必須重寫此函式。有兩種定義前向傳播的方法
用法 1(合併前向傳播和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必須接受一個上下文 ctx 作為第一個引數,後面可以跟任意數量的引數(張量或其他型別)。
更多詳細資訊請參閱 合併或分離 forward() 和 setup_context()
用法 2(分離前向傳播和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 引數。
相反,您還必須重寫
torch.autograd.Function.setup_context()靜態方法來處理ctx物件的設定。output是前向傳播的輸出,inputs是前向傳播輸入的元組(Tuple)。更多詳細資訊請參閱 擴充套件 torch.autograd
上下文可用於儲存可在反向傳播過程中檢索的任意資料。張量不應直接儲存在 ctx 上(儘管出於向後相容性目前並未強制執行)。相反,如果張量打算用於
backward(等同於vjp),則應使用ctx.save_for_backward()儲存;如果張量打算用於jvp,則應使用ctx.save_for_forward()儲存。- 返回型別