torch.nn.utils.parametrizations.spectral_norm¶
- torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[source][source]¶
將譜歸一化應用於給定模組中的引數。
當應用於向量時,它簡化為
譜歸一化透過降低模型的 Lipschitz 常數,穩定了生成對抗網路 (GAN) 中判別器(評論家)的訓練。每次訪問權重時,透過執行一次冪迭代法來近似計算 。如果權重張量的維度大於 2,則在冪迭代方法中將其重塑為 2D 以獲得譜範數。
參見 用於生成對抗網路的譜歸一化 。
注意
此函式是使用
register_parametrization()中的 parametrization 功能實現的。它是torch.nn.utils.spectral_norm()的重新實現。注意
註冊此約束後,將估計與最大奇異值相關聯的奇異向量,而不是隨機取樣。然後,當在 training 模式下訪問模組中的張量時,透過執行
n_power_iterations次冪迭代法來更新這些奇異向量。注意
如果 _SpectralNorm 模組(即 module.parametrization.weight[idx])在移除時處於訓練模式,它將執行另一次冪迭代。如果您想避免這次迭代,請在移除前將模組設定為評估模式。
- 引數
- 返回值
註冊了新 parametrization 的原始模組
- 返回型別
示例
>>> snm = spectral_norm(nn.Linear(20, 40)) >>> snm ParametrizedLinear( in_features=20, out_features=40, bias=True (parametrizations): ModuleDict( (weight): ParametrizationList( (0): _SpectralNorm() ) ) ) >>> torch.linalg.matrix_norm(snm.weight, 2) tensor(1.0081, grad_fn=<AmaxBackward0>)