修補批次標準化¶
發生了什麼事?¶
批次標準化需要對與輸入相同大小的 running_mean 和 running_var 進行就地更新。Functorch 不支援對採用批次張量(即不允許 regular.add_(batched))的常規張量進行就地更新。因此,當對單一模組的輸入批次進行 vmap 映射時,我們會遇到此錯誤
如何修復¶
其中一種最佳支援的方法是將 BatchNorm 切換為 GroupNorm。選項 1 和 2 支援此方法
所有這些選項都假設您不需要執行統計資料。如果您正在使用模組,這表示假設您不會在評估模式下使用批次標準化。如果您有在評估模式下使用 vmap 執行批次標準化的用例,請提交問題
選項 1:變更 BatchNorm¶
如果您想變更為 GroupNorm,請在任何使用 BatchNorm 的地方將其替換為
BatchNorm2d(C, G, track_running_stats=False)
這裡的 C 與原始 BatchNorm 中的 C 相同。G 是將 C 分成幾組。因此,C % G == 0,作為後備措施,您可以設定 C == G,這表示每個通道都將單獨處理。
如果您必須使用 BatchNorm 且您自己建構了模組,則可以變更模組以不使用執行統計資料。換句話說,在任何有 BatchNorm 模組的地方,將 track_running_stats 旗標設定為 False
BatchNorm2d(64, track_running_stats=False)
選項 2:torchvision 參數¶
某些 torchvision 模型(例如 resnet 和 regnet)可以採用 norm_layer 參數。如果已預設,這些參數通常會預設為 BatchNorm2d。
您可以改為將其設定為 GroupNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))
同樣地,這裡的 c % g == 0,因此作為後備措施,請設定 g = c。
如果您堅持使用 BatchNorm,請務必使用不使用執行統計資料的版本
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
選項 3:functorch 的修補¶
functorch 新增了一些功能,允許快速就地修補模組以不使用執行統計資料。變更標準化層級比較不穩定,因此我們沒有提供該功能。如果您希望 BatchNorm 不使用執行統計資料,則可以執行 replace_all_batch_norm_modules_ 來就地更新模組以不使用執行統計資料
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
選項 4:評估模式¶
在評估模式下執行時,不會更新 running_mean 和 running_var。因此,vmap 可以支援此模式
model.eval()
vmap(model)(x)
model.train()