RMSNorm¶
- class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[source][source]¶
對輸入 mini-batch 應用均方根層歸一化。
此層實現了論文 Root Mean Square Layer Normalization 中描述的操作
均方根 (RMS) 計算是基於最後
D個維度進行的,其中D是normalized_shape的維度。例如,如果normalized_shape為(3, 5)(一個二維形狀),則均方根計算將應用於輸入的最後 2 個維度。- 引數
- 形狀
輸入:
輸出: (與輸入形狀相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)