修補 Batch Norm¶
發生了什麼?¶
Batch Norm 需要對與輸入具有相同尺寸的 running_mean 和 running_var 進行原地更新。Functorch 不支援對接收批處理張量(即 regular.add_(batched) 不允許)的常規張量進行原地更新。因此,當對單個模組的輸入批次進行 vmapping 時,我們會遇到此錯誤
如何修復¶
一種最佳支援的方式是將 BatchNorm 替換為 GroupNorm。選項 1 和 2 支援此方法
所有這些選項都假設您不需要執行統計量(running stats)。如果您正在使用模組,這意味著假設您不會在評估模式下使用 BatchNorm。如果您需要在評估模式下使用 vmap 執行 BatchNorm,請提交一個 issue
選項 1:更改 BatchNorm¶
如果您想將其更改為 GroupNorm,請將所有 BatchNorm 替換為
BatchNorm2d(C, G, track_running_stats=False)
這裡 C 與原始 BatchNorm 中的 C 相同。G 是將 C 劃分成的組數。因此,C % G == 0 成立,作為備用方案,您可以設定 C == G,這意味著每個通道都將被單獨處理。
如果您必須使用 BatchNorm 並且是自己構建的模組,則可以更改模組以不使用執行統計量。換句話說,在任何 BatchNorm 模組處,將 track_running_stats 標誌設定為 False
BatchNorm2d(64, track_running_stats=False)
選項 2:torchvision 引數¶
一些 torchvision 模型,如 resnet 和 regnet,可以接受 norm_layer 引數。如果使用了預設值,這些引數通常會預設為 BatchNorm2d。
您可以將其設定為 GroupNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))
這裡,同樣,c % g == 0 成立,因此作為備用方案,設定 g = c。
如果您堅持使用 BatchNorm,請務必使用不使用執行統計量的版本
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
選項 3:functorch 的修補¶
functorch 添加了一些功能,可以快速、原地修補模組以不使用執行統計量。更改規範層更脆弱,因此我們沒有提供該選項。如果您有一個網路,希望其中的 BatchNorm 不使用執行統計量,可以執行 replace_all_batch_norm_modules_ 來原地更新模組以不使用執行統計量。
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
選項 4:評估模式¶
在評估模式下執行時,running_mean 和 running_var 將不會更新。因此,vmap 可以支援此模式。
model.eval()
vmap(model)(x)
model.train()