快捷方式

torch.linalg.householder_product

torch.linalg.householder_product(A, tau, *, out=None) Tensor

計算 Householder 矩陣乘積的前 n 列。

K\mathbb{K}R\mathbb{R}C\mathbb{C},設 AKm×nA \in \mathbb{K}^{m \times n} 為一個矩陣,其列向量為 aiKma_i \in \mathbb{K}^m(對於 i=1,,mi=1,\ldots,m,且 mnm \geq n)。記 bib_i 為將 aia_i 的前 i1i-1 個分量置零、第 ii 個分量設為 `1` 所得的向量。對於一個向量 τKk\tau \in \mathbb{K}^k(其中 knk \leq n),本函式計算以下矩陣的前 nn

H1H2...HkwithHi=ImτibibiHH_1H_2 ... H_k \qquad\text{with}\qquad H_i = \mathrm{I}_m - \tau_i b_i b_i^{\text{H}}

其中 Im\mathrm{I}_mm 維單位矩陣,$b^{\text{H}}$ 在 $b$ 是複數時表示共軛轉置,在 $b$ 是實數時表示轉置。輸出矩陣的大小與輸入矩陣 A 相同。

有關更多詳細資訊,請參見 Representation of Orthogonal or Unitary Matrices

支援 float、double、cfloat 和 cdouble 資料型別的輸入。也支援批處理矩陣輸入,如果輸入是批處理矩陣,則輸出具有相同的批處理維度。

另請參閱

torch.geqrf() 可與本函式結合使用,從 qr() 分解中形成 Q 矩陣。

torch.ormqr() 是一個相關函式,用於計算 Householder 矩陣乘積與另一個矩陣的矩陣乘法。但是,該函式不支援自動求導。

警告

只有當 τi1ai2\tau_i \neq \frac{1}{||a_i||^2} 時,梯度計算才是良好定義的。如果未滿足此條件,不會丟擲錯誤,但生成的梯度可能包含 NaN

引數
  • A (Tensor) – 形狀為 (*, m, n) 的張量,其中 * 表示零個或多個批處理維度。

  • tau (Tensor) – 形狀為 (*, k) 的張量,其中 * 表示零個或多個批處理維度。

關鍵字引數

out (Tensor, 可選) – 輸出張量。如果為 None 則忽略。預設值:None

丟擲

RuntimeError – 如果 A 不滿足 m >= n 的要求,或 tau 不滿足 n >= k 的要求。

示例

>>> A = torch.randn(2, 2)
>>> h, tau = torch.geqrf(A)
>>> Q = torch.linalg.householder_product(h, tau)
>>> torch.dist(Q, torch.linalg.qr(A).Q)
tensor(0.)

>>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
>>> tau = torch.randn(3, 1, dtype=torch.complex128)
>>> Q = torch.linalg.householder_product(h, tau)
>>> Q
tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
        [-0.6853+0.7953j,  2.0790+0.5620j]],

        [[ 1.4581+1.6989j, -1.5360+0.1193j],
        [ 1.3877-0.6691j,  1.3512+1.3024j]],

        [[ 1.4766+0.5783j,  0.0361+0.6587j],
        [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源