torch.nn.utils.stateless.functional_call¶
- torch.nn.utils.stateless.functional_call(module, parameters_and_buffers, args=None, kwargs=None, *, tie_weights=True, strict=False)[原始碼][原始碼]¶
透過用提供的引數和緩衝區替換模組的引數和緩衝區來執行模組上的函式式呼叫。
警告
此 API 已在 PyTorch 2.0 中棄用,並將在未來版本中移除。請改用
torch.func.functional_call(),它是此 API 的直接替代品。注意
如果模組具有活躍的引數化 (parametrization),在
parameters_and_buffers引數中傳入一個名稱設定為常規引數名稱的值將完全停用該引數化。如果要將引數化函式應用於傳入的值,請將鍵設定為{submodule_name}.parametrizations.{parameter_name}.original。注意
如果模組對引數/緩衝區執行原地 (in-place) 操作,這些操作將反映在 parameters_and_buffers 輸入中。
示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # does self.foo = self.foo + 1 >>> print(mod.foo) # tensor(0.) >>> functional_call(mod, a, torch.ones(())) >>> print(mod.foo) # tensor(0.) >>> print(a['foo']) # tensor(1.)
注意
如果模組具有繫結權重 (tied weights),則 functional_call 是否遵循繫結取決於 tie_weights 標誌。
示例
>>> a = {'foo': torch.zeros(())} >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied >>> print(mod.foo) # tensor(1.) >>> mod(torch.zeros(())) # tensor(2.) >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
- 引數
module (torch.nn.Module) – 要呼叫的模組
parameters_and_buffers (dict of str and Tensor) – 將在模組呼叫中使用的引數。
args (Any or tuple) – 要傳遞給模組呼叫的引數。如果不是元組,則視為單個引數。
kwargs (dict) – 要傳遞給模組呼叫的關鍵字引數
tie_weights (bool, optional) – 如果為 True,則原始模型中繫結的引數和緩衝區在重新引數化版本中也將被視為繫結。因此,如果為 True 且為繫結的引數和緩衝區傳入了不同的值,將引發錯誤。如果為 False,則不會遵循原始繫結的引數和緩衝區,除非為兩個權重傳入的值相同。預設值:True。
strict (bool, optional) – 如果為 True,則傳入的引數和緩衝區必須與原始模組中的引數和緩衝區匹配。因此,如果為 True 且存在任何缺失或意外的鍵,將引發錯誤。預設值:False。
- 返回值
呼叫
module的結果。- 返回值型別
Any