快捷方式

torch.func.grad_and_value

torch.func.grad_and_value(func, argnums=0, has_aux=False)[source]

返回一個函式,用於計算梯度和原始(或前向)計算的元組。

引數
  • func (Callable) – 一個接受一個或多個引數的 Python 函式。必須返回一個單元素 Tensor。如果指定了 has_aux 等於 True,函式可以返回一個包含單元素 Tensor 和其他輔助物件的元組: (output, aux)

  • argnums (intTuple[int]) – 指定要計算梯度的引數。argnums 可以是一個整數或整數元組。預設值:0。

  • has_aux (bool) – 一個標誌,指示 func 返回一個 tensor 和其他輔助物件: (output, aux)。預設值:False。

返回

返回一個函式,用於計算相對於輸入的梯度和前向計算的元組。預設情況下,該函式的輸出是相對於第一個引數的梯度 tensor(s) 和原始計算的元組。如果指定了 has_aux 等於 True,則返回梯度元組和包含輸出輔助物件的前向計算元組。如果 argnums 是一個整數元組,則返回一個元組,該元組包含相對於每個 argnums 值的輸出梯度元組以及前向計算。

返回型別

Callable

參見 grad() 獲取示例

文件

查閱 PyTorch 全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源