torch.func¶
torch.func,之前稱為「functorch」,是類似 JAX 的 PyTorch 可組合函數轉換。
注意
這個函式庫目前處於 測試 階段。這表示這些功能通常可以正常運作(除非另有說明),而且我們(PyTorch 團隊)致力於推廣這個函式庫。但是,API可能會根據使用者回饋而改變,而且我們並沒有完全涵蓋 PyTorch 的所有操作。
如果您對 API 或您希望涵蓋的使用案例有任何建議,請提出 GitHub 問題或與我們聯繫。我們很樂意瞭解您如何使用這個函式庫。
什麼是可組合函數轉換?¶
- 「函數轉換」是一種高階函數,它接受一個數值函數並返回一個計算不同數量的新函數。 
- torch.func具有自動微分轉換(- grad(f)返回一個計算- f的梯度的函數)、向量化/批次處理轉換(- vmap(f)返回一個計算一批輸入的- f的函數)等等。
- 這些函數轉換可以任意地相互組合。例如,組合 - vmap(grad(f))可以計算一個稱為「每個樣本梯度」的量,這是目前 PyTorch 無法有效計算的。