捷徑

從 functorch 遷移到 torch.func

torch.func,以前稱為「functorch」,是 PyTorch 的 JAX 類似 可組合函數轉換。

functorch 最初是一個位於 pytorch/functorch 儲存庫的樹外函式庫。我們的目標始終是將 functorch 直接上游到 PyTorch 中,並將其作為核心 PyTorch 函式庫提供。

作為上游的最後一步,我們決定從頂級套件 (functorch) 遷移到成為 PyTorch 的一部分,以反映函數轉換如何直接整合到 PyTorch 核心。從 PyTorch 2.0 開始,我們將棄用 import functorch,並要求使用者遷移到我們將繼續維護的最新 API。import functorch 將保留幾個版本,以維持向後相容性。

函數轉換

以下 API 是以下 functorch API 的直接替代品。它們完全向後相容。

functorch API

PyTorch API(從 PyTorch 2.0 開始)

functorch.vmap

torch.vmap()torch.func.vmap()

functorch.grad

torch.func.grad()

functorch.vjp

torch.func.vjp()

functorch.jvp

torch.func.jvp()

functorch.jacrev

torch.func.jacrev()

functorch.jacfwd

torch.func.jacfwd()

functorch.hessian

torch.func.hessian()

functorch.functionalize

torch.func.functionalize()

此外,如果您使用的是 torch.autograd.functional API,請嘗試使用 torch.func 等效項。torch.func 函數轉換在許多情況下更具可組合性和效能。

torch.autograd.functional API

torch.func API(從 PyTorch 2.0 開始)

torch.autograd.functional.vjp()

torch.func.grad()torch.func.vjp()

torch.autograd.functional.jvp()

torch.func.jvp()

torch.autograd.functional.jacobian()

torch.func.jacrev()torch.func.jacfwd()

torch.autograd.functional.hessian()

torch.func.hessian()

NN 模組公用程式

我們已更改在 NN 模組上套用函數轉換的 API,使其更符合 PyTorch 的設計理念。新的 API 不同,因此請仔細閱讀本節。

functorch.make_functional

torch.func.functional_call()functorch.make_functionalfunctorch.make_functional_with_buffers 的替代品。但是,它不是直接替代品。

如果您趕時間,可以使用 此 gist 中的輔助函數 來模擬 functorch.make_functional 和 functorch.make_functional_with_buffers 的行為。我們建議直接使用 torch.func.functional_call(),因為它是一個更明確且更靈活的 API。

具體來說,functorch.make_functional 會傳回一個函數模組和參數。函數模組接受模型的參數和輸入作為參數。torch.func.functional_call() 允許使用者使用新的參數、緩衝區和輸入來呼叫現有模組的前向傳遞。

以下是如何使用 functorch 和 torch.func 計算模型參數梯度的範例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)

以下是如何計算模型參數雅可比矩陣的範例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

請注意,對於記憶體消耗而言,重要的是您應該只保留一份參數副本。model.named_parameters() 不會複製參數。如果在模型訓練中,您就地更新模型的參數,則身為模型的 nn.Module 會擁有參數的單一副本,並且一切正常。

然而,如果您想將參數保存在字典中並在外部更新它們,則參數會有兩個副本:一個在字典中,另一個在 model 中。在這種情況下,您應該透過 model.to('meta')model 轉換為元設備,使其不佔用內存。

functorch.combine_state_for_ensemble

請使用 torch.func.stack_module_state() 來代替 functorch.combine_state_for_ensembletorch.func.stack_module_state() 會返回兩個字典,一個是堆疊的參數,另一個是堆疊的緩衝區,然後可以將它們與 torch.vmap()torch.func.functional_call() 一起使用來進行集成。

例如,以下是如何對一個非常簡單的模型進行集成的示例

import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

functorch.compile

我們不再支持 functorch.compile(也稱為 AOTAutograd)作為 PyTorch 中編譯的前端;我們已將 AOTAutograd 整合到 PyTorch 的編譯流程中。如果您是使用者,請改用 torch.compile()

文件

取得 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題的解答

查看資源