快捷方式

torch.func

torch.func,以前稱為“functorch”,是 PyTorch 中類似 JAX 的可組合函式變換。

注意

該庫目前處於 測試階段(beta)。這意味著其功能通常可以工作(除非另有說明),並且我們(PyTorch 團隊)致力於推進該庫。然而,API 可能會根據使用者反饋進行更改,而且我們尚未完全覆蓋所有 PyTorch 操作。

如果您對 API 或希望涵蓋的用例有任何建議,請在 GitHub 上提交 issue 或聯絡我們。我們很樂意瞭解您如何使用該庫。

什麼是可組合函式變換?

  • “函式變換”是一種高階函式,它接受一個數值函式,並返回一個計算不同量的新函式。

  • torch.func 包含自動微分變換(grad(f) 返回計算 f 梯度的函式)、向量化/批處理變換(vmap(f) 返回在輸入批次上計算 f 的函式)等。

  • 這些函式變換可以任意組合。例如,組合 vmap(grad(f)) 可以計算稱為“每樣本梯度”的量,而當前的 PyTorch 尚無法高效計算此量。

為什麼選擇可組合函式變換?

目前在 PyTorch 中實現許多用例比較棘手

  • 計算每樣本梯度(或其他每樣本量)

  • 在單臺機器上執行模型整合

  • 在 MAML 內迴圈中高效地將任務批次處理

  • 高效計算 Jacobian 和 Hessian

  • 高效計算批次 Jacobian 和 Hessian

組合 vmap()grad()vjp() 變換,使我們無需為每個用例設計單獨的子系統即可實現上述功能。這種可組合函式變換的思想源自 JAX 框架

文件

查閱全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源