快捷方式

torch.func API 參考

函式變換

vmap

vmap 是向量化對映;vmap(func) 返回一個新函式,該函式將 func 對映到輸入的某個維度上。

grad

grad 運算子有助於計算 func 相對於 argnums 指定的輸入(或多個輸入)的梯度。

grad_and_value

返回一個函式,用於計算梯度和原始計算(或前向計算)組成的元組。

vjp

vjp 代表向量-雅可比乘積 (vector-Jacobian product),返回一個元組,其中包含應用於 primalsfunc 的結果,以及一個給定 cotangents 後計算 func 相對於 primals 的反向模式雅可比(再乘以 cotangents)的函式。

jvp

jvp 代表雅可比-向量乘積 (Jacobian-vector product),返回一個元組,其中包含 func(*primals) 的輸出以及“在 primals 處計算的 func 的雅可比”乘以 tangents 的結果。

linearize

返回 funcprimals 處的值以及在 primals 處的線性近似值。

jacrev

使用反向模式自動微分計算 func 相對於索引 argnum 處的引數(或多個引數)的雅可比。

jacfwd

使用前向模式自動微分計算 func 相對於索引 argnum 處的引數(或多個引數)的雅可比。

hessian

透過前向-後向策略計算 func 相對於索引 argnum 處的引數(或多個引數)的 Hessian 矩陣。

functionalize

functionalize 是一種變換,可用於從函式中去除(中間)修改和別名,同時保留函式的語義。

用於處理 torch.nn.Modules 的工具

通常,您可以對呼叫 torch.nn.Module 的函式進行變換。例如,下面是一個計算接受三個值並返回三個值的函式的雅可比的示例

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

但是,如果您想對模型的引數計算雅可比之類的操作,則需要一種方法來構造一個將引數作為函式輸入的函式。這就是 functional_call() 的作用:它接受一個 nn.Module、變換後的 parameters 以及 Module 前向傳播的輸入。它返回使用替換引數執行 Module 前向傳播的值。

下面是如何計算引數的雅可比

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

透過用提供的引數和緩衝區替換模組的引數和緩衝區,對模組執行函式式呼叫。

stack_module_state

準備一個 torch.nn.Modules 列表,以便與 vmap() 一起進行整合。

replace_all_batch_norm_modules_

透過將 running_meanrunning_var 設定為 None 並將 track_running_stats 設定為 False,原地更新 root 中的任何 nn.BatchNorm 模組。

如果您正在尋找關於修復 Batch Norm 模組的資訊,請遵循此處的指南

除錯工具

debug_unwrap

展開一個 functorch tensor(例如

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源