捷徑

torch.func

torch.func,之前稱為「functorch」,是類似 JAX 的 PyTorch 可組合函數轉換。

注意

這個函式庫目前處於 測試 階段。這表示這些功能通常可以正常運作(除非另有說明),而且我們(PyTorch 團隊)致力於推廣這個函式庫。但是,API可能會根據使用者回饋而改變,而且我們並沒有完全涵蓋 PyTorch 的所有操作。

如果您對 API 或您希望涵蓋的使用案例有任何建議,請提出 GitHub 問題或與我們聯繫。我們很樂意瞭解您如何使用這個函式庫。

什麼是可組合函數轉換?

  • 「函數轉換」是一種高階函數,它接受一個數值函數並返回一個計算不同數量的新函數。

  • torch.func 具有自動微分轉換(grad(f) 返回一個計算 f 的梯度的函數)、向量化/批次處理轉換(vmap(f) 返回一個計算一批輸入的 f 的函數)等等。

  • 這些函數轉換可以任意地相互組合。例如,組合 vmap(grad(f)) 可以計算一個稱為「每個樣本梯度」的量,這是目前 PyTorch 無法有效計算的。

為什麼要使用可組合函數轉換?

目前在 PyTorch 中,有許多使用案例難以實現

  • 計算每個樣本梯度(或其他每個樣本的量)

  • 在單一機器上執行模型的集成

  • 在 MAML 的內迴圈中有效地將任務批次化

  • 有效地計算雅可比矩陣和海塞矩陣

  • 有效地計算批次的雅可比矩陣和海塞矩陣

組合 vmap()grad()vjp() 轉換讓我們無需為每個子系統設計獨立的系統,就能夠表達上述內容。這種可組合函數轉換的概念來自於 JAX 框架

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學

取得針對初學者和進階開發人員的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源