捷徑

修補批次標準化

發生了什麼事?

批次標準化需要對與輸入相同大小的 running_mean 和 running_var 進行就地更新。Functorch 不支援對採用批次張量(即不允許 regular.add_(batched))的常規張量進行就地更新。因此,當對單一模組的輸入批次進行 vmap 映射時,我們會遇到此錯誤

如何修復

其中一種最佳支援的方法是將 BatchNorm 切換為 GroupNorm。選項 1 和 2 支援此方法

所有這些選項都假設您不需要執行統計資料。如果您正在使用模組,這表示假設您不會在評估模式下使用批次標準化。如果您有在評估模式下使用 vmap 執行批次標準化的用例,請提交問題

選項 1:變更 BatchNorm

如果您想變更為 GroupNorm,請在任何使用 BatchNorm 的地方將其替換為

BatchNorm2d(C, G, track_running_stats=False)

這裡的 C 與原始 BatchNorm 中的 C 相同。G 是將 C 分成幾組。因此,C % G == 0,作為後備措施,您可以設定 C == G,這表示每個通道都將單獨處理。

如果您必須使用 BatchNorm 且您自己建構了模組,則可以變更模組以不使用執行統計資料。換句話說,在任何有 BatchNorm 模組的地方,將 track_running_stats 旗標設定為 False

BatchNorm2d(64, track_running_stats=False)

選項 2:torchvision 參數

某些 torchvision 模型(例如 resnet 和 regnet)可以採用 norm_layer 參數。如果已預設,這些參數通常會預設為 BatchNorm2d。

您可以改為將其設定為 GroupNorm。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))

同樣地,這裡的 c % g == 0,因此作為後備措施,請設定 g = c

如果您堅持使用 BatchNorm,請務必使用不使用執行統計資料的版本

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

選項 3:functorch 的修補

functorch 新增了一些功能,允許快速就地修補模組以不使用執行統計資料。變更標準化層級比較不穩定,因此我們沒有提供該功能。如果您希望 BatchNorm 不使用執行統計資料,則可以執行 replace_all_batch_norm_modules_ 來就地更新模組以不使用執行統計資料

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

選項 4:評估模式

在評估模式下執行時,不會更新 running_mean 和 running_var。因此,vmap 可以支援此模式

model.eval()
vmap(model)(x)
model.train()

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得適用於初學者和進階開發者的深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源