torch.func.functional_call¶
- torch.func.functional_call(module, parameter_and_buffer_dicts, args=None, kwargs=None, *, tie_weights=True, strict=False)[源]¶
透過替換模組引數和緩衝區來對模組執行函式式呼叫。
注意
如果模組具有活躍的引數化 (parametrizations),在
parameter_and_buffer_dicts引數中傳入一個值,其名稱設定為常規引數名,將完全停用引數化。如果您想對傳入的值應用引數化函式,請將鍵設定為{submodule_name}.parametrizations.{parameter_name}.original。注意
如果模組對引數/緩衝區執行原地 (in-place) 操作,這些操作將反映在
parameter_and_buffer_dicts輸入中。示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # does self.foo = self.foo + 1 >>> print(mod.foo) # tensor(0.) >>> functional_call(mod, a, torch.ones(())) >>> print(mod.foo) # tensor(0.) >>> print(a['foo']) # tensor(1.)
注意
如果模組有共享權重 (tied weights),
functional_call是否遵循共享取決於tie_weights標誌。示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied >>> print(mod.foo) # tensor(1.) >>> mod(torch.zeros(())) # tensor(2.) >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
傳入多個字典的示例
a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer print(mod.weight) # tensor(...) print(mod.buffer) # tensor(...) x = torch.randn((1, 1)) print(x) functional_call(mod, a, x) # same as x print(mod.weight) # same as before functional_call
這裡是應用梯度變換 (grad transform) 到模型引數的示例。
import torch import torch.nn as nn from torch.func import functional_call, grad x = torch.randn(4, 3) t = torch.randn(4, 3) model = nn.Linear(3, 3) def compute_loss(params, x, t): y = functional_call(model, params, x) return nn.functional.mse_loss(y, t) grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
注意
如果使用者在梯度變換之外不需要梯度跟蹤,他們可以分離所有引數,以獲得更好的效能和記憶體使用。
示例
>>> detached_params = {k: v.detach() for k, v in model.named_parameters()} >>> grad_weights = grad(compute_loss)(detached_params, x, t) >>> grad_weights.grad_fn # None--it's not tracking gradients outside of grad
這意味著使用者無法呼叫
grad_weight.backward()。但是,如果他們不需要在變換之外進行自動微分跟蹤,這將減少記憶體使用並提高速度。- 引數
module (torch.nn.Module) – 要呼叫的模組
parameters_and_buffer_dicts (Dict[str, Tensor] 或 tuple of Dict[str, Tensor]) – 將在模組呼叫中使用的引數和緩衝區。如果給定一個字典元組,它們的鍵必須是唯一的,以便所有字典可以一起使用。
args (Any 或 tuple) – 傳遞給模組呼叫的位置引數。如果不是元組,則視為單個引數。
kwargs (dict) – 傳遞給模組呼叫的關鍵字引數
tie_weights (bool, 可選) – 如果為 True,則原始模型中共享的引數和緩衝區在重新引數化版本中也將被視為共享。因此,如果為 True 且為共享引數和緩衝區傳遞了不同的值,則會報錯。如果為 False,則不會遵循原始的共享引數和緩衝區,除非為兩個權重傳遞的值相同。預設為:True。
strict (bool, 可選) – 如果為 True,則傳入的引數和緩衝區必須與原始模組中的引數和緩衝區匹配。因此,如果為 True 且存在任何缺失或意外的鍵,則會報錯。預設為:False。
- 返回
呼叫
module的結果。- 返回型別
Any