torch.func API 參考¶
函數轉換¶
| vmap 是向量化映射; | |
| 
 | |
| 傳回一個函數,用於計算梯度和原始計算(或正向計算)的元組。 | |
| 代表向量-雅可比矩陣乘積,傳回一個元組,其中包含將  | |
| 代表雅可比矩陣-向量乘積,傳回一個元組,其中包含 func(*primals) 的輸出以及「在  | |
| 傳回  | |
| 使用反向模式自動微分計算  | |
| 使用正向模式自動微分計算  | |
| 透過正向對反向策略計算  | |
| functionalize 是一種轉換,可用於從函數中移除(中間)突變和別名,同時保留函數的語義。 | 
用於處理 torch.nn.Modules 的工具¶
一般而言,您可以對呼叫 torch.nn.Module 的函數進行轉換。例如,以下是一個計算雅可比矩陣的範例,該函數採用三個值並傳回三個值
model = torch.nn.Linear(3, 3)
def f(x):
    return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
但是,如果您想對模型的參數進行計算雅可比矩陣之類的操作,則需要一種方法來建構一個函數,其中參數是函數的輸入。這就是 functional_call() 的用途:它接受一個 nn.Module、轉換後的 參數,以及模組正向傳遞的輸入。它會傳回使用替換後的參數執行模組正向傳遞的值。
以下是我們如何計算參數的雅可比矩陣
model = torch.nn.Linear(3, 3)
def f(params, x):
    return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
| 透過將模組參數和緩衝區替換為提供的參數和緩衝區,對模組執行函數呼叫。 | |
| 準備一個 torch.nn.Modules 清單,以便使用  | |
| 透過將  | 
如果您正在尋找有關修復批次正規化模組的資訊,請遵循此處的指南