SyncBatchNorm¶
- class torch.nn.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, device=None, dtype=None)[source][source]¶
對 N 維輸入應用 Batch Normalization。
N 維輸入是一個 [N-2] 維輸入的 mini-batch(帶有額外的通道維度),如論文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 中所述。
均值和標準差是針對同一程序組的所有 mini-batch,按維度計算的。 和 是大小為 C(C 是輸入大小)的可學習引數向量。預設情況下, 的元素從 中取樣, 的元素設定為 0。標準差透過有偏估計器計算,等同於 torch.var(input, unbiased=False)。
同樣預設情況下,在訓練期間,此層會保留其計算出的均值和方差的執行估計值,這些估計值隨後用於評估期間的歸一化。執行估計值使用預設
momentum0.1 進行維護。如果
track_running_stats設定為False,則此層不保留執行估計值,並且在評估期間也使用 batch 統計資訊。注意
此
momentum引數與最佳化器類中使用的動量以及傳統的動量概念不同。數學上,此處的執行統計資料更新規則為 ,其中 是估計統計量,而 是新的觀測值。由於 Batch Normalization 是在
C維度中為每個通道完成的,即在(N, +)切片上計算統計資料,因此通常將此稱為 Volumetric Batch Normalization 或 Spatio-temporal Batch Normalization。目前
SyncBatchNorm僅支援每個程序一個 GPU 的DistributedDataParallel(DDP)。在用 DDP 封裝網路之前,使用torch.nn.SyncBatchNorm.convert_sync_batchnorm()將BatchNorm*D層轉換為SyncBatchNorm。- 引數
num_features (int) – 來自預期輸入大小 中的
eps (float) – 新增到分母上的值,用於數值穩定性。預設值:
1e-5momentum (Optional[float]) – 用於 running_mean 和 running_var 計算的值。可設定為
None表示累積移動平均(即簡單平均)。預設值: 0.1affine (bool) – 一個布林值,當設定為
True時,此模組具有可學習的仿射引數。預設值:Truetrack_running_stats (bool) – 一個布林值,當設定為
True時,此模組跟蹤執行均值和方差;當設定為False時,此模組不跟蹤此類統計資訊,並將統計緩衝區running_mean和running_var初始化為None。當這些緩衝區為None時,此模組在訓練和評估模式下始終使用 batch 統計資訊。預設值:Trueprocess_group (Optional[Any]) – 統計資料的同步在每個程序組內單獨進行。預設行為是在整個世界範圍內同步
- 形狀
輸入:
輸出: (形狀與輸入相同)
注意
batchnorm 統計資料的同步僅在訓練期間發生,即當
model.eval()設定為或self.training為False時,同步將被停用。示例
>>> # With Learnable Parameters >>> m = nn.SyncBatchNorm(100) >>> # creating process group (optional) >>> # ranks is a list of int identifying rank ids. >>> ranks = list(range(8)) >>> r1, r2 = ranks[:4], ranks[4:] >>> # Note: every rank calls into new_group for every >>> # process group created, even if that rank is not >>> # part of the group. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> # Without Learnable Parameters >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group) >>> input = torch.randn(20, 100, 35, 45, 10) >>> output = m(input) >>> # network is nn.BatchNorm layer >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group) >>> # only single gpu per process is currently supported >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel( >>> sync_bn_network, >>> device_ids=[args.local_rank], >>> output_device=args.local_rank)
- classmethod convert_sync_batchnorm(module, process_group=None)[source][source]¶
將模型中的所有
BatchNorm*D層轉換為torch.nn.SyncBatchNorm層。- 引數
module (nn.Module) – 包含一個或多個
BatchNorm*D層的模組process_group (optional) – 用於確定同步範圍的程序組,預設為整個世界
- 返回值
轉換後的
torch.nn.SyncBatchNorm層的原始module。如果原始module是BatchNorm*D層,則將返回一個新的torch.nn.SyncBatchNorm層物件。
示例
>>> # Network with nn.BatchNorm layer >>> module = torch.nn.Sequential( >>> torch.nn.Linear(20, 100), >>> torch.nn.BatchNorm1d(100), >>> ).cuda() >>> # creating process group (optional) >>> # ranks is a list of int identifying rank ids. >>> ranks = list(range(8)) >>> r1, r2 = ranks[:4], ranks[4:] >>> # Note: every rank calls into new_group for every >>> # process group created, even if that rank is not >>> # part of the group. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]] >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1] >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)