torch.linalg.lstsq¶
- torch.linalg.lstsq(A, B, rcond=None, *, driver=None)¶
計算線性方程組最小二乘問題的解。
令 為 或 ,線性系統 (其中 )的**最小二乘問題**定義為
其中 表示 Frobenius 範數。
支援 float, double, cfloat 和 cdouble 資料型別的輸入。也支援矩陣的批次輸入,如果輸入是批次矩陣,則輸出具有相同的批次維度。
driver選擇將使用的後端函式。對於 CPU 輸入,有效值包括 ‘gels’, ‘gelsy’, ‘gelsd, ‘gelss’。選擇最佳 CPU 驅動時考慮:如果
A條件良好(其條件數不太大),或者您不介意一些精度損失。對於一般矩陣:‘gelsy’ (帶主元旋轉的 QR) (預設)
如果
A是滿秩矩陣:‘gels’ (QR)
如果
A條件不好。‘gelsd’ (三對角化約簡和 SVD)
如果遇到記憶體問題:‘gelss’ (完全 SVD)。
對於 CUDA 輸入,唯一有效的驅動是 ‘gels’,它假定
A是滿秩矩陣。另請參閱這些驅動的完整描述
rcond用於確定當driver是 (‘gelsy’, ‘gelsd’, ‘gelss’) 之一時,A中矩陣的有效秩。在此情況下,如果 是按降序排列的 A 的奇異值,則當 時, 將被向下舍入為零。如果rcond= None (預設),則rcond將設定為A的 dtype 的機器精度乘以 max(m, n)。此函式以包含四個 Tensor 的命名元組 (solution, residuals, rank, singular_values) 形式返回問題的解和一些額外資訊。對於形狀分別為 (*, m, n) 和 (*, m, k) 的輸入
A、B,它包含:solution: 最小二乘解。其形狀為 (*, n, k)。
residuals: 解的平方殘差,即 。其形狀為 (*, k)。當 m > n 且
A中的每個矩陣都是滿秩時計算此值,否則返回空 Tensor。如果A是批次矩陣,且批次中任何矩陣不是滿秩,則返回空 Tensor。此行為在未來的 PyTorch 版本中可能會改變。rank:
A中矩陣的秩 Tensor。其形狀與A的批次維度相同。當driver是 (‘gelsy’, ‘gelsd’, ‘gelss’) 之一時計算此值,否則返回空 Tensor。singular_values:
A中矩陣的奇異值 Tensor。其形狀為 (*, min(m, n))。當driver是 (‘gelsd’, ‘gelss’) 之一時計算此值,否則返回空 Tensor。
注意
此函式以比單獨計算更快且數值更穩定的方式計算 X =
A.pinverse() @B。警告
rcond的預設值在未來的 PyTorch 版本中可能會改變。因此,建議使用固定值以避免潛在的破壞性更改。- 引數
- 關鍵字引數
driver (str, optional) – 將使用的 LAPACK/MAGMA 方法的名稱。如果 None,CPU 輸入使用 ‘gelsy’,CUDA 輸入使用 ‘gels’。預設值: None。
- 返回值
一個命名元組 (solution, residuals, rank, singular_values)。
示例
>>> A = torch.randn(1,3,3) >>> A tensor([[[-1.0838, 0.0225, 0.2275], [ 0.2438, 0.3844, 0.5499], [ 0.1175, -0.9102, 2.0870]]]) >>> B = torch.randn(2,3,3) >>> B tensor([[[-0.6772, 0.7758, 0.5109], [-1.4382, 1.3769, 1.1818], [-0.3450, 0.0806, 0.3967]], [[-1.3994, -0.1521, -0.1473], [ 1.9194, 1.0458, 0.6705], [-1.1802, -0.9796, 1.4086]]]) >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) >>> torch.dist(X, torch.linalg.pinv(A) @ B) tensor(1.5152e-06) >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values >>> torch.dist(S, torch.linalg.svdvals(A)) tensor(2.3842e-07) >>> A[:, 0].zero_() # Decrease the rank of A >>> rank = torch.linalg.lstsq(A, B).rank >>> rank tensor([2])