快捷方式

torch.nn.utils.parametrizations.spectral_norm

torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[source][source]

將譜歸一化應用於給定模組中的引數。

WSN=Wσ(W),σ(W)=maxh:h0Wh2h2\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

當應用於向量時,它簡化為

xSN=xx2\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

譜歸一化透過降低模型的 Lipschitz 常數,穩定了生成對抗網路 (GAN) 中判別器(評論家)的訓練。每次訪問權重時,透過執行一次冪迭代法來近似計算 σ\sigma。如果權重張量的維度大於 2,則在冪迭代方法中將其重塑為 2D 以獲得譜範數。

參見 用於生成對抗網路的譜歸一化

注意

此函式是使用 register_parametrization() 中的 parametrization 功能實現的。它是 torch.nn.utils.spectral_norm() 的重新實現。

注意

註冊此約束後,將估計與最大奇異值相關聯的奇異向量,而不是隨機取樣。然後,當在 training 模式下訪問模組中的張量時,透過執行 n_power_iterations冪迭代法來更新這些奇異向量。

注意

如果 _SpectralNorm 模組(即 module.parametrization.weight[idx])在移除時處於訓練模式,它將執行另一次冪迭代。如果您想避免這次迭代,請在移除前將模組設定為評估模式。

引數
  • module (nn.Module) – 包含模組

  • name (str, optional) – 權重引數名稱。預設值:"weight"

  • n_power_iterations (int, optional) – 用於計算譜範數的冪迭代次數。預設值:1

  • eps (float, optional) – 計算範數時的數值穩定性 epsilon。預設值:1e-12

  • dim (int, optional) – 對應於輸出數量的維度。預設值:0,對於 ConvTranspose{1,2,3}d 模組例項除外,此時為 1

返回值

註冊了新 parametrization 的原始模組

返回型別

Module

示例

>>> snm = spectral_norm(nn.Linear(20, 40))
>>> snm
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0081, grad_fn=<AmaxBackward0>)

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源