torch.func.hessian¶
- torch.func.hessian(func, argnums=0)[source]¶
透過前向-反向策略計算
func對索引argnum處引數的 Hessian。前向-反向策略(組合
jacfwd(jacrev(func)))是實現良好效能的預設選擇。也可以透過jacfwd()和jacrev()的其他組合來計算 Hessian,例如jacfwd(jacfwd(func))或jacrev(jacrev(func))。- 引數
- 返回
返回一個函式,該函式接受與
func相同的輸入,並返回func相對於argnums處引數的 Hessian。
注意
您可能會遇到此 API 報告“前向模式 AD 未為運算子 X 實現”的錯誤。如果發生這種情況,請提交 bug 報告,我們將優先處理。另一種選擇是使用
jacrev(jacrev(func)),它具有更好的運算子覆蓋範圍。對於 R^N -> R^1 函式,基本用法會得到一個 N x N 的 Hessian 矩陣
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))