快捷方式

torch.func.jvp

torch.func.jvp(func, primals, tangents, *, strict=False, has_aux=False)[source]

代表著雅可比向量積,返回一個包含 func(*primals) 的輸出以及“在 primals 處計算的 func 的雅可比”乘以 tangents 的元組。這也稱為前向自動微分。

引數
  • func (function) – 一個 Python 函式,接受一個或多個引數(其中一個必須是 Tensor),並返回一個或多個 Tensor

  • primals (Tensors) – func 的位置引數,必須都是 Tensor。返回的函式也將計算相對於這些引數的導數。

  • tangents (Tensors) – 計算雅可比向量積的“向量”。必須與 func 的輸入具有相同的結構和尺寸。

  • has_aux (bool) – 標誌,指示 func 返回一個 (output, aux) 元組,其中第一個元素是要微分的函式輸出,第二個元素是其他不會被微分的輔助物件。預設值:False。

返回值

返回一個包含在 primals 處計算的 func 的輸出以及雅可比向量積的 (output, jvp_out) 元組。如果 has_aux True,則改為返回一個 (output, jvp_out, aux) 元組。

注意

您可能會看到此 API 報錯“operator X 未實現前向自動微分”。如果發生這種情況,請提交錯誤報告,我們將優先處理。

當您希望計算函式 R^1 -> R^N 的梯度時,jvp 非常有用

>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))

jvp() 可以透過為每個輸入傳遞 tangents 來支援具有多個輸入的函式

>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的疑問

檢視資源