快捷方式

torch.func.linearize

torch.func.linearize(func, *primals)[source]

返回 primals 處的 func 值以及在 primals 處的線性近似。

引數
  • func (Callable) – 一個接受一個或多個引數的 Python 函式。

  • primals (Tensors) – func 的位置引數,必須全部是 Tensor。這些是函式進行線性近似時的值。

返回值

返回一個 (output, jvp_fn) 元組,包含應用於 primalsfunc 輸出,以及一個計算在 primals 處求值的 func 的 jvp 的函式。

返回型別

tuple[Any, Callable]

如果在 primals 處多次計算 jvp,則 linearize 會很有用。但是,為此,linearize 會儲存中間計算結果,並且比直接應用 jvp 需要更高的記憶體。因此,如果所有 tangents 都已知,計算 vmap(jvp) 可能比使用 linearize 更高效。

注意

linearize 會兩次評估 func。請提交一個 issue 以實現單次評估版本。

示例:
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>

© 版權所有 PyTorch 貢獻者。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源