torch.func¶
torch.func,以前稱為“functorch”,是 PyTorch 中類似 JAX 的可組合函式變換。
注意
該庫目前處於 測試階段(beta)。這意味著其功能通常可以工作(除非另有說明),並且我們(PyTorch 團隊)致力於推進該庫。然而,API 可能會根據使用者反饋進行更改,而且我們尚未完全覆蓋所有 PyTorch 操作。
如果您對 API 或希望涵蓋的用例有任何建議,請在 GitHub 上提交 issue 或聯絡我們。我們很樂意瞭解您如何使用該庫。
什麼是可組合函式變換?¶
“函式變換”是一種高階函式,它接受一個數值函式,並返回一個計算不同量的新函式。
torch.func包含自動微分變換(grad(f)返回計算f梯度的函式)、向量化/批處理變換(vmap(f)返回在輸入批次上計算f的函式)等。這些函式變換可以任意組合。例如,組合
vmap(grad(f))可以計算稱為“每樣本梯度”的量,而當前的 PyTorch 尚無法高效計算此量。