torch.linalg.householder_product¶
- torch.linalg.householder_product(A, tau, *, out=None) Tensor¶
計算 Householder 矩陣乘積的前 n 列。
設 為 或 ,設 為一個矩陣,其列向量為 (對於 ,且 )。記 為將 的前 個分量置零、第 個分量設為 `1` 所得的向量。對於一個向量 (其中 ),本函式計算以下矩陣的前 列
其中 是 m 維單位矩陣,$b^{\text{H}}$ 在 $b$ 是複數時表示共軛轉置,在 $b$ 是實數時表示轉置。輸出矩陣的大小與輸入矩陣
A相同。有關更多詳細資訊,請參見 Representation of Orthogonal or Unitary Matrices。
支援 float、double、cfloat 和 cdouble 資料型別的輸入。也支援批處理矩陣輸入,如果輸入是批處理矩陣,則輸出具有相同的批處理維度。
另請參閱
torch.geqrf()可與本函式結合使用,從qr()分解中形成 Q 矩陣。torch.ormqr()是一個相關函式,用於計算 Householder 矩陣乘積與另一個矩陣的矩陣乘法。但是,該函式不支援自動求導。警告
只有當 時,梯度計算才是良好定義的。如果未滿足此條件,不會丟擲錯誤,但生成的梯度可能包含 NaN。
- 引數
- 關鍵字引數
out (Tensor, 可選) – 輸出張量。如果為 None 則忽略。預設值:None。
- 丟擲
RuntimeError – 如果
A不滿足 m >= n 的要求,或tau不滿足 n >= k 的要求。
示例
>>> A = torch.randn(2, 2) >>> h, tau = torch.geqrf(A) >>> Q = torch.linalg.householder_product(h, tau) >>> torch.dist(Q, torch.linalg.qr(A).Q) tensor(0.) >>> h = torch.randn(3, 2, 2, dtype=torch.complex128) >>> tau = torch.randn(3, 1, dtype=torch.complex128) >>> Q = torch.linalg.householder_product(h, tau) >>> Q tensor([[[ 1.8034+0.4184j, 0.2588-1.0174j], [-0.6853+0.7953j, 2.0790+0.5620j]], [[ 1.4581+1.6989j, -1.5360+0.1193j], [ 1.3877-0.6691j, 1.3512+1.3024j]], [[ 1.4766+0.5783j, 0.0361+0.6587j], [ 0.6396+0.1612j, 1.3693+0.4481j]]], dtype=torch.complex128)