torch.func.linearize¶
- torch.func.linearize(func, *primals)[source]¶
返回
primals處的func值以及在primals處的線性近似。- 引數
func (Callable) – 一個接受一個或多個引數的 Python 函式。
primals (Tensors) –
func的位置引數,必須全部是 Tensor。這些是函式進行線性近似時的值。
- 返回值
返回一個
(output, jvp_fn)元組,包含應用於primals的func輸出,以及一個計算在primals處求值的func的 jvp 的函式。- 返回型別
如果在
primals處多次計算 jvp,則 linearize 會很有用。但是,為此,linearize 會儲存中間計算結果,並且比直接應用 jvp 需要更高的記憶體。因此,如果所有tangents都已知,計算 vmap(jvp) 可能比使用 linearize 更高效。注意
linearize 會兩次評估
func。請提交一個 issue 以實現單次評估版本。- 示例:
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>