快捷方式

torch.func.hessian

torch.func.hessian(func, argnums=0)[source]

透過前向-反向策略計算 func 對索引 argnum 處引數的 Hessian。

前向-反向策略(組合 jacfwd(jacrev(func)))是實現良好效能的預設選擇。也可以透過 jacfwd()jacrev() 的其他組合來計算 Hessian,例如 jacfwd(jacfwd(func))jacrev(jacrev(func))

引數
  • func (function) – 一個 Python 函式,接受一個或多個引數(其中一個必須是 Tensor),並返回一個或多個 Tensor

  • argnums (intTuple[int]) – 可選,整數或整數元組,表示計算 Hessian 時相對於哪些引數。預設值:0。

返回

返回一個函式,該函式接受與 func 相同的輸入,並返回 func 相對於 argnums 處引數的 Hessian。

注意

您可能會遇到此 API 報告“前向模式 AD 未為運算子 X 實現”的錯誤。如果發生這種情況,請提交 bug 報告,我們將優先處理。另一種選擇是使用 jacrev(jacrev(func)),它具有更好的運算子覆蓋範圍。

對於 R^N -> R^1 函式,基本用法會得到一個 N x N 的 Hessian 矩陣

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源