torch.nn.utils.skip_init¶
- torch.nn.utils.skip_init(module_cls, *args, **kwargs)[原始碼][原始碼]¶
給定一個模組類物件以及 args / kwargs,在不初始化引數 / 緩衝區的情況下例項化該模組。
如果初始化速度慢,或者如果將執行自定義初始化,從而使得預設初始化不必要,這可能很有用。由於此函式的實現方式,存在一些注意事項:
1. 模組的建構函式必須接受一個 device 引數,該引數會傳遞給在構建過程中建立的任何引數或緩衝區。
2. 除了初始化(即來自
torch.nn.init的函式)外,模組的建構函式不得對引數執行任何計算。如果滿足這些條件,則可以例項化引數 / 緩衝區值未初始化的模組,就像使用
torch.empty()建立一樣。- 引數
module_cls – 類物件;應該是
torch.nn.Module的子類。args – 傳遞給模組建構函式的 args。
kwargs – 傳遞給模組建構函式的 kwargs。
- 返回
例項化後引數 / 緩衝區未初始化的模組。
示例
>>> import torch >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) >>> m.weight Parameter containing: tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]], requires_grad=True) >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1) >>> m2.weight Parameter containing: tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, 4.5915e-41]], requires_grad=True)