GroupNorm¶
- class torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)[source][source]¶
對輸入的小批次資料應用 Group Normalization。
此層實現論文 Group Normalization 中描述的操作。
輸入通道被分成
num_groups組,每組包含num_channels / num_groups個通道。num_channels必須能被num_groups整除。均值和標準差分別在每組內計算。 和 是按通道學習的仿射變換引數向量,如果affine為True,它們的大小為num_channels。方差是使用有偏估計器計算的,相當於 torch.var(input, unbiased=False)。此層在訓練和評估模式下都使用從輸入資料計算的統計資訊。
- 引數
- 形狀
輸入:,其中
輸出: (形狀與輸入相同)
示例
>>> input = torch.randn(20, 6, 10, 10) >>> # Separate 6 channels into 3 groups >>> m = nn.GroupNorm(3, 6) >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) >>> m = nn.GroupNorm(6, 6) >>> # Put all 6 channels into a single group (equivalent with LayerNorm) >>> m = nn.GroupNorm(1, 6) >>> # Activating the module >>> output = m(input)