快捷方式

torch.autograd.Function.forward

static Function.forward(*args, **kwargs)[源]

定義自定義 autograd Function 的前向傳播。

所有子類都必須重寫此函式。有兩種定義前向傳播的方法

用法 1(合併前向傳播和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2(分離前向傳播和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 引數。

  • 相反,您還必須重寫 torch.autograd.Function.setup_context() 靜態方法來處理 ctx 物件的設定。 output 是前向傳播的輸出,inputs 是前向傳播輸入的元組(Tuple)。

  • 更多詳細資訊請參閱 擴充套件 torch.autograd

上下文可用於儲存可在反向傳播過程中檢索的任意資料。張量不應直接儲存在 ctx 上(儘管出於向後相容性目前並未強制執行)。相反,如果張量打算用於 backward(等同於 vjp),則應使用 ctx.save_for_backward() 儲存;如果張量打算用於 jvp,則應使用 ctx.save_for_forward() 儲存。

返回型別

Any

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源