捷徑

torch.func API 參考

函數轉換

vmap

vmap 是向量化映射;vmap(func) 會傳回一個新函數,該函數會將 func 映射到輸入的某個維度。

grad

grad 運算子有助於計算 func 相對於 argnums 指定的輸入的梯度。

grad_and_value

傳回一個函數,用於計算梯度和原始計算(或正向計算)的元組。

vjp

代表向量-雅可比矩陣乘積,傳回一個元組,其中包含將 func 應用於 primals 的結果,以及一個函數,當給定 cotangents 時,會計算 func 相對於 primals 的反向模式雅可比矩陣乘以 cotangents

jvp

代表雅可比矩陣-向量乘積,傳回一個元組,其中包含 func(*primals) 的輸出以及「在 primals 處評估的 func 的雅可比矩陣」乘以 tangents

linearize

傳回 funcprimals 處的值和在 primals 處的線性近似值。

jacrev

使用反向模式自動微分計算 func 相對於索引 argnum 處的參數的雅可比矩陣

jacfwd

使用正向模式自動微分計算 func 相對於索引 argnum 處的參數的雅可比矩陣

hessian

透過正向對反向策略計算 func 相對於索引 argnum 處的參數的海森矩陣。

functionalize

functionalize 是一種轉換,可用於從函數中移除(中間)突變和別名,同時保留函數的語義。

用於處理 torch.nn.Modules 的工具

一般而言,您可以對呼叫 torch.nn.Module 的函數進行轉換。例如,以下是一個計算雅可比矩陣的範例,該函數採用三個值並傳回三個值

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

但是,如果您想對模型的參數進行計算雅可比矩陣之類的操作,則需要一種方法來建構一個函數,其中參數是函數的輸入。這就是 functional_call() 的用途:它接受一個 nn.Module、轉換後的 參數,以及模組正向傳遞的輸入。它會傳回使用替換後的參數執行模組正向傳遞的值。

以下是我們如何計算參數的雅可比矩陣

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

透過將模組參數和緩衝區替換為提供的參數和緩衝區,對模組執行函數呼叫。

stack_module_state

準備一個 torch.nn.Modules 清單,以便使用 vmap() 進行整合。

replace_all_batch_norm_modules_

透過將 running_meanrunning_var 設定為 None 並將 track_running_stats 設定為 False,對 root 中的任何 nn.BatchNorm 模組進行就地更新 root

如果您正在尋找有關修復批次正規化模組的資訊,請遵循此處的指南

文件

存取 PyTorch 的完整開發人員文件

查看文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源