LazyModuleMixin¶
- class torch.nn.modules.lazy.LazyModuleMixin(*args, **kwargs)[原始碼][原始碼]¶
用於延遲初始化引數的模組的 mixin 類,也稱為“延遲模組”。
延遲初始化引數的模組,即“延遲模組”,會從其 forward 方法的第一個輸入中推斷其引數的形狀。在第一次 forward 之前,它們包含
torch.nn.UninitializedParameter,不應訪問或使用;之後,它們包含常規的torch.nn.Parameter。延遲模組很方便,因為它們不需要計算某些模組引數,例如典型的torch.nn.Linear的in_features引數。構造之後,包含延遲模組的網路應首先轉換為所需的 dtype 並放置在預期的裝置上。這是因為延遲模組僅執行形狀推斷,因此常規的 dtype 和裝置放置行為仍然適用。然後,延遲模組應執行“空執行”(dry run)來初始化模組中的所有元件。這些“空執行”會將具有正確大小、dtype 和裝置的輸入透過網路傳送到每個延遲模組。之後,網路就可以正常使用了。
>>> class LazyMLP(torch.nn.Module): ... def __init__(self) -> None: ... super().__init__() ... self.fc1 = torch.nn.LazyLinear(10) ... self.relu1 = torch.nn.ReLU() ... self.fc2 = torch.nn.LazyLinear(1) ... self.relu2 = torch.nn.ReLU() ... ... def forward(self, input): ... x = self.relu1(self.fc1(input)) ... y = self.relu2(self.fc2(x)) ... return y >>> # constructs a network with lazy modules >>> lazy_mlp = LazyMLP() >>> # transforms the network's device and dtype >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs' >>> lazy_mlp = lazy_mlp.cuda().double() >>> lazy_mlp LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True) (relu1): ReLU() (fc2): LazyLinear(in_features=0, out_features=1, bias=True) (relu2): ReLU() ) >>> # performs a dry run to initialize the network's lazy modules >>> lazy_mlp(torch.ones(10,10).cuda()) >>> # after initialization, LazyLinear modules become regular Linear modules >>> lazy_mlp LazyMLP( (fc1): Linear(in_features=10, out_features=10, bias=True) (relu1): ReLU() (fc2): Linear(in_features=10, out_features=1, bias=True) (relu2): ReLU() ) >>> # attaches an optimizer, since parameters can now be used as usual >>> optim = torch.optim.SGD(mlp.parameters(), lr=0.01)
使用延遲模組時的最後一個注意事項是,網路的引數初始化順序可能會發生變化,因為延遲模組總是在其他模組之後初始化。例如,如果上面定義的 LazyMLP 類首先包含
torch.nn.LazyLinear模組,然後包含一個常規的torch.nn.Linear模組,那麼第二個模組會在構造時初始化,而第一個模組會在第一次空執行時初始化。這可能導致使用延遲模組的網路的引數初始化方式與不使用延遲模組的網路不同,因為引數初始化的順序不同,這通常取決於一個有狀態的隨機數生成器。有關更多詳細資訊,請參閱可復現性。延遲模組可以像其他模組一樣使用 state dict 進行序列化。例如
>>> lazy_mlp = LazyMLP() >>> # The state dict shows the uninitialized parameters >>> lazy_mlp.state_dict() OrderedDict([('fc1.weight', Uninitialized parameter), ('fc1.bias', tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), ('fc2.weight', Uninitialized parameter), ('fc2.bias', tensor([0.0019]))])
延遲模組可以載入常規的
torch.nn.Parameter(即,您可以序列化/反序列化已初始化的 LazyModules,並且它們將保持初始化狀態)>>> full_mlp = LazyMLP() >>> # Dry run to initialize another module >>> full_mlp.forward(torch.ones(10, 1)) >>> # Load an initialized state into a lazy module >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) >>> # The state dict now holds valid values >>> lazy_mlp.state_dict() OrderedDict([('fc1.weight', tensor([[-0.3837], [ 0.0907], [ 0.6708], [-0.5223], [-0.9028], [ 0.2851], [-0.4537], [ 0.6813], [ 0.5766], [-0.8678]])), ('fc1.bias', tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), ('fc2.weight', tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807, 0.2479, 0.1091]])), ('fc2.bias', tensor([0.0019]))])
然而,請注意,如果載入狀態時引數已初始化,則在執行“空執行”時不會替換載入的引數。這可以防止在不同上下文中使用已初始化的模組。