torch.func API 參考¶
函式變換¶
vmap 是向量化對映; |
|
|
|
返回一個函式,用於計算梯度和原始計算(或前向計算)組成的元組。 |
|
vjp 代表向量-雅可比乘積 (vector-Jacobian product),返回一個元組,其中包含應用於 |
|
jvp 代表雅可比-向量乘積 (Jacobian-vector product),返回一個元組,其中包含 func(*primals) 的輸出以及“在 |
|
返回 |
|
使用反向模式自動微分計算 |
|
使用前向模式自動微分計算 |
|
透過前向-後向策略計算 |
|
functionalize 是一種變換,可用於從函式中去除(中間)修改和別名,同時保留函式的語義。 |
用於處理 torch.nn.Modules 的工具¶
通常,您可以對呼叫 torch.nn.Module 的函式進行變換。例如,下面是一個計算接受三個值並返回三個值的函式的雅可比的示例
model = torch.nn.Linear(3, 3)
def f(x):
return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
但是,如果您想對模型的引數計算雅可比之類的操作,則需要一種方法來構造一個將引數作為函式輸入的函式。這就是 functional_call() 的作用:它接受一個 nn.Module、變換後的 parameters 以及 Module 前向傳播的輸入。它返回使用替換引數執行 Module 前向傳播的值。
下面是如何計算引數的雅可比
model = torch.nn.Linear(3, 3)
def f(params, x):
return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
透過用提供的引數和緩衝區替換模組的引數和緩衝區,對模組執行函式式呼叫。 |
|
準備一個 torch.nn.Modules 列表,以便與 |
|
透過將 |
如果您正在尋找關於修復 Batch Norm 模組的資訊,請遵循此處的指南
除錯工具¶
展開一個 functorch tensor(例如 |