torch.func.stack_module_state
- torch.func.stack_module_state(models) -> params, buffers
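Prepares a list of nn.Modules for ensembling with vmap().
Given a list of M nn.Modules of the same class, returns two dictionaries that stack all of their parameters and buffers together, indexed by name. The stacked parameters are optimizable: they are new leaf nodes in the autograd history, unrelated to the original parameters, and can be passed directly to an optimizer.
Here is an example of how to ensemble over a very simple model: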
    import torch
    from torch.func import functional_call, stack_module_state, vmap

    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, in_features)

    def wrapper(params, buffers, data):
        # Run one model's forward pass with the supplied (unstacked) state.
        return functional_call(models[0], (params, buffers), data)

    params, buffers = stack_module_state(models)
    # Map over the leading (model) dimension of params and buffers; share data.
    output = vmap(wrapper, (0, 0, None))(params, buffers, data)

    assert output.shape == (num_models, batch_size, out_features)
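The statement that the stacked parameters are optimizable can be checked with a short sketch. The random regression targets, the choice of SGD, and the learning rate are illustrative assumptions and are not part of the original example.

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
num_models, batch_size = 5, 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)
targets = torch.randn(batch_size, out_features)  # assumed toy targets

params, buffers = stack_module_state(models)

# Each stacked parameter should be a new leaf tensor that requires grad.
all_leaves = all(p.is_leaf and p.requires_grad for p in params.values())

def wrapper(params, buffers, data):
    return functional_call(models[0], (params, buffers), data)

# Hand the stacked tensors to an optimizer directly (SGD is an arbitrary choice).
optimizer = torch.optim.SGD(list(params.values()), lr=0.1)

stacked_before = params["weight"].detach().clone()
original_before = models[0].weight.detach().clone()

output = vmap(wrapper, (0, 0, None))(params, buffers, data)
loss = ((output - targets) ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The stacked copy is updated; the module it was built from is left alone.
stacked_changed = not torch.equal(stacked_before, params["weight"])
original_unchanged = torch.equal(original_before, models[0].weight)
print(all_leaves, stacked_changed, original_unchanged)
```

Because the stacked tensors are separate leaves, training them does not write anything back into the modules in `models`; copying the results back, if needed, is a separate step.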
When there are submodules, the dictionary keys follow state dict naming conventions:
    import torch.nn as nn
    from torch.func import stack_module_state

    class Foo(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            hidden = 4
            self.l1 = nn.Linear(in_features, hidden)
            self.l2 = nn.Linear(hidden, out_features)

        def forward(self, x):
            return self.l2(self.l1(x))

    num_models = 5
    in_features, out_features = 3, 3
    models = [Foo(in_features, out_features) for i in range(num_models)]
    params, buffers = stack_module_state(models)
    print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"
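To make the stacking concrete, the following sketch inspects the shapes of the stacked entries for the same `Foo` model and compares one slice against the module it came from. It only uses calls already shown above; the variable names `shapes` and `slices_match` are illustrative.

```python
import torch
import torch.nn as nn
from torch.func import stack_module_state

class Foo(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 4
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.l2(self.l1(x))

num_models = 5
models = [Foo(3, 3) for _ in range(num_models)]
params, buffers = stack_module_state(models)

# Every entry gains a leading dimension of size num_models.
shapes = {name: tuple(p.shape) for name, p in params.items()}
for name, shape in shapes.items():
    print(name, shape)

# Slice i of a stacked entry holds the values of model i's parameter.
slices_match = all(
    torch.equal(params["l1.weight"][i], models[i].l1.weight)
    for i in range(num_models)
)

# Foo registers no buffers, so the buffers dictionary is empty here.
print(slices_match, buffers)
```

The leading dimension is the one that `vmap` maps over when `in_dims` is 0 for `params` and `buffers`.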
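Warning
All modules that are stacked together must be identical except for the values of their parameters and buffers. For example, they should all be in the same mode (training vs. eval).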
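One way to guard against the situation in this warning is to validate the list before stacking. The helper below, `check_stackable`, is hypothetical and not part of torch.func; the three checks it performs (same class, same training flag, same parameter and buffer names and shapes) are an assumed reading of "identical" rather than a documented contract.

```python
import torch.nn as nn
from torch.func import stack_module_state

def check_stackable(models):
    """Hypothetical pre-check; raises ValueError on the first mismatch found."""
    if not models:
        raise ValueError("expected at least one model")
    first = models[0]
    first_shapes = {k: tuple(v.shape) for k, v in first.state_dict().items()}
    for i, m in enumerate(models[1:], start=1):
        if type(m) is not type(first):
            raise ValueError(f"model {i} has a different class than model 0")
        if m.training != first.training:
            raise ValueError(f"model {i} is in a different train/eval mode than model 0")
        shapes = {k: tuple(v.shape) for k, v in m.state_dict().items()}
        if shapes != first_shapes:
            raise ValueError(f"model {i} has different parameter/buffer names or shapes")

models = [nn.Linear(3, 3) for _ in range(4)]
models[2].eval()  # one model is accidentally left in eval mode

try:
    check_stackable(models)
    mismatch = None
except ValueError as err:
    mismatch = str(err)
print(mismatch)

# Put every model in the same mode, then stack.
for m in models:
    m.eval()
check_stackable(models)
params, buffers = stack_module_state(models)
```

Running the check first gives an error that names the offending model, independently of whatever validation `stack_module_state` performs itself.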