快捷方式

torch.nn.utils.parametrizations.orthogonal

torch.nn.utils.parametrizations.orthogonal(module, name='weight', orthogonal_map=None, *, use_trivialization=True)[原始碼][原始碼]

將正交或酉引數化應用於矩陣或批次矩陣。

K\mathbb{K}R\mathbb{R}C\mathbb{C},引數化後的矩陣 QKm×nQ \in \mathbb{K}^{m \times n}正交的,定義如下:

QHQ=Inif mnQQH=Imif m<n\begin{align*} Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} \end{align*}

其中 QHQ^{\text{H}} 是當 QQ 是複數時為共軛轉置,當 QQ 是實數時為轉置,In\mathrm{I}_nn 維單位矩陣。簡而言之,當 mnm \geq n 時,QQ 將具有正交列,否則具有正交行。

如果張量具有兩個以上的維度,我們將其視為形狀為 (…, m, n) 的批次矩陣。

矩陣 QQ 可以透過三個不同的 orthogonal_map(相對於原始張量)進行引數化:

  • "matrix_exp"/"cayley": 將 matrix_exp() Q=exp(A)Q = \exp(A)Cayley 對映 Q=(In+A/2)(InA/2)1Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1} 應用於一個斜對稱矩陣 AA 以得到一個正交矩陣。

  • "householder": 計算 Householder 反射積 (householder_product())。

"matrix_exp"/"cayley" 通常使引數化後的權重比 "householder" 收斂更快,但對於非常“瘦”或非常“寬”的矩陣,它們的計算速度較慢。

如果 use_trivialization=True(預設值),則引數化實現了“動態平凡化框架”(Dynamic Trivialization Framework),其中一個額外的矩陣 BKn×nB \in \mathbb{K}^{n \times n} 儲存在 module.parametrizations.weight[0].base 下。這有助於引數化層的收斂,但會消耗一些額外的記憶體。請參閱 Trivializations for Gradient-Based Optimization on Manifolds

QQ 的初始值:如果原始張量未被引數化且 use_trivialization=True(預設值),則 QQ 的初始值如果原始張量本身是正交的(或在複數情況下是酉的),則使用原始張量的值;否則,透過 QR 分解進行正交化(參見 torch.linalg.qr())。當未引數化且 orthogonal_map="householder" 時,即使 use_trivialization=False,也會發生同樣的情況。否則,初始值是應用於原始張量的所有已註冊引數化組合的結果。

注意

此函式使用 register_parametrization() 中的引數化功能實現。

引數
  • module (nn.Module) – 要在其上註冊引數化的模組。

  • name (str, 可選) – 要進行正交化的張量名稱。預設值:"weight"

  • orthogonal_map (str, 可選) – 以下之一:"matrix_exp""cayley""householder"。預設值:如果矩陣是方陣或複數,則為 "matrix_exp";否則為 "householder"

  • use_trivialization (bool, 可選) – 是否使用動態平凡化框架。預設值:True

返回值

已將正交引數化註冊到指定權重的原始模組

返回型別

Module

示例

>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
    (weight): ParametrizationList(
    (0): _Orthogonal()
    )
)
)
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源