從 functorch 遷移到 torch.func¶
torch.func,之前稱為 “functorch”,是 PyTorch 中類似 JAX 的可組合函式變換。
functorch 最初是 pytorch/functorch 倉庫中的一個獨立庫。我們的目標始終是將 functorch 直接合併到 PyTorch 主幹中,並將其作為核心 PyTorch 庫提供。
作為合併主幹的最後一步,我們決定從一個頂級包(functorch)遷移到成為 PyTorch 的一部分,以反映函式變換是如何直接整合到 PyTorch 核心中的。從 PyTorch 2.0 開始,我們棄用了 import functorch,並要求使用者遷移到我們將繼續維護的最新 API。import functorch 將保留幾個版本,以保持向後相容性。
函式變換¶
以下 API 可以直接替換以下 functorch API。它們完全向後相容。
functorch API |
PyTorch API (截至 PyTorch 2.0) |
|---|---|
functorch.vmap |
|
functorch.grad |
|
functorch.vjp |
|
functorch.jvp |
|
functorch.jacrev |
|
functorch.jacfwd |
|
functorch.hessian |
|
functorch.functionalize |
此外,如果您正在使用 torch.autograd.functional API,請嘗試使用相應的 torch.func 替代項。torch.func 函式變換在許多情況下更具可組合性和更高效能。
torch.autograd.functional API |
torch.func API (截至 PyTorch 2.0) |
|---|---|
NN 模組工具¶
我們更改了應用於 NN 模組的函式變換 API,使其更符合 PyTorch 的設計理念。新的 API 有所不同,請仔細閱讀本節。
functorch.make_functional¶
torch.func.functional_call() 替代了 functorch.make_functional 和 functorch.make_functional_with_buffers。然而,它不能完全直接替換。
如果時間緊急,您可以使用此 Gist 中的輔助函式來模擬 functorch.make_functional 和 functorch.make_functional_with_buffers 的行為。我們建議直接使用 torch.func.functional_call(),因為它是一個更明確、更靈活的 API。
具體來說,functorch.make_functional 返回一個函式式模組和引數。該函式式模組接受引數和模型輸入作為引數。torch.func.functional_call() 允許使用新的引數、緩衝區和輸入呼叫現有模組的前向傳遞。
這裡有一個使用 functorch 和 torch.func 計算模型引數梯度的示例
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
這裡有一個計算模型引數 Jacobian 矩陣的示例
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
請注意,為了控制記憶體消耗,只應保留引數的單個副本,這一點很重要。model.named_parameters() 不會複製引數。如果在模型訓練中原地更新模型引數,則模型中的 nn.Module 擁有引數的單個副本,一切正常。
但是,如果想將引數儲存在字典中並非原地更新,則會有兩份引數副本:一份在字典中,另一份在 model 中。在這種情況下,應該透過 model.to('meta') 將 model 轉換為 meta 裝置,使其不持有記憶體。
functorch.combine_state_for_ensemble¶
請使用 torch.func.stack_module_state() 替代 functorch.combine_state_for_ensemble。torch.func.stack_module_state() 返回兩個字典,一個包含堆疊的引數,另一個包含堆疊的緩衝區,然後可以與 torch.vmap() 和 torch.func.functional_call() 一起用於整合(ensembling)。
例如,這裡有一個關於如何對一個非常簡單的模型進行整合(ensemble)的示例
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy
# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
functorch.compile¶
我們不再支援將 functorch.compile(也稱為 AOTAutograd)作為 PyTorch 編譯的前端;我們已將 AOTAutograd 整合到 PyTorch 的編譯體系中。如果您是使用者,請轉而使用 torch.compile()。