快捷方式

從 functorch 遷移到 torch.func

torch.func,之前稱為 “functorch”,是 PyTorch 中類似 JAX 的可組合函式變換。

functorch 最初是 pytorch/functorch 倉庫中的一個獨立庫。我們的目標始終是將 functorch 直接合併到 PyTorch 主幹中,並將其作為核心 PyTorch 庫提供。

作為合併主幹的最後一步,我們決定從一個頂級包(functorch)遷移到成為 PyTorch 的一部分,以反映函式變換是如何直接整合到 PyTorch 核心中的。從 PyTorch 2.0 開始,我們棄用了 import functorch,並要求使用者遷移到我們將繼續維護的最新 API。import functorch 將保留幾個版本,以保持向後相容性。

函式變換

以下 API 可以直接替換以下 functorch API。它們完全向後相容。

functorch API

PyTorch API (截至 PyTorch 2.0)

functorch.vmap

torch.vmap()torch.func.vmap()

functorch.grad

torch.func.grad()

functorch.vjp

torch.func.vjp()

functorch.jvp

torch.func.jvp()

functorch.jacrev

torch.func.jacrev()

functorch.jacfwd

torch.func.jacfwd()

functorch.hessian

torch.func.hessian()

functorch.functionalize

torch.func.functionalize()

此外,如果您正在使用 torch.autograd.functional API,請嘗試使用相應的 torch.func 替代項。torch.func 函式變換在許多情況下更具可組合性和更高效能。

torch.autograd.functional API

torch.func API (截至 PyTorch 2.0)

torch.autograd.functional.vjp()

torch.func.grad()torch.func.vjp()

torch.autograd.functional.jvp()

torch.func.jvp()

torch.autograd.functional.jacobian()

torch.func.jacrev()torch.func.jacfwd()

torch.autograd.functional.hessian()

torch.func.hessian()

NN 模組工具

我們更改了應用於 NN 模組的函式變換 API,使其更符合 PyTorch 的設計理念。新的 API 有所不同,請仔細閱讀本節。

functorch.make_functional

torch.func.functional_call() 替代了 functorch.make_functionalfunctorch.make_functional_with_buffers。然而,它不能完全直接替換。

如果時間緊急,您可以使用此 Gist 中的輔助函式來模擬 functorch.make_functional 和 functorch.make_functional_with_buffers 的行為。我們建議直接使用 torch.func.functional_call(),因為它是一個更明確、更靈活的 API。

具體來說,functorch.make_functional 返回一個函式式模組和引數。該函式式模組接受引數和模型輸入作為引數。torch.func.functional_call() 允許使用新的引數、緩衝區和輸入呼叫現有模組的前向傳遞。

這裡有一個使用 functorch 和 torch.func 計算模型引數梯度的示例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)

這裡有一個計算模型引數 Jacobian 矩陣的示例

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

請注意,為了控制記憶體消耗,只應保留引數的單個副本,這一點很重要。model.named_parameters() 不會複製引數。如果在模型訓練中原地更新模型引數,則模型中的 nn.Module 擁有引數的單個副本,一切正常。

但是,如果想將引數儲存在字典中並非原地更新,則會有兩份引數副本:一份在字典中,另一份在 model 中。在這種情況下,應該透過 model.to('meta')model 轉換為 meta 裝置,使其不持有記憶體。

functorch.combine_state_for_ensemble

請使用 torch.func.stack_module_state() 替代 functorch.combine_state_for_ensembletorch.func.stack_module_state() 返回兩個字典,一個包含堆疊的引數,另一個包含堆疊的緩衝區,然後可以與 torch.vmap()torch.func.functional_call() 一起用於整合(ensembling)。

例如,這裡有一個關於如何對一個非常簡單的模型進行整合(ensemble)的示例

import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

functorch.compile

我們不再支援將 functorch.compile(也稱為 AOTAutograd)作為 PyTorch 編譯的前端;我們已將 AOTAutograd 整合到 PyTorch 的編譯體系中。如果您是使用者,請轉而使用 torch.compile()

文件

查閱 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源