torch.triangular_solve¶
- torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)¶
求解具有方形上三角或下三角可逆矩陣 和多個右側項 的方程組。
符號表示為求解 ,並假設 是方形上三角矩陣(如果
upper= False 則為下三角矩陣),且對角線上沒有零。torch.triangular_solve(b, A) 可以接受 2D 輸入 b, A 或 2D 矩陣的批次輸入。如果輸入是批次,則返回批次的輸出 X
如果
A的對角線包含零或非常接近零的元素,並且unitriangular= False(預設值),或者輸入矩陣是病態的,則結果可能包含 NaN。支援 float, double, cfloat 和 cdouble 資料型別的輸入。
警告
torch.triangular_solve()已被棄用,推薦使用torch.linalg.solve_triangular(),並將在未來的 PyTorch 版本中移除。torch.linalg.solve_triangular()的引數順序相反,並且不返回其中一個輸入的副本。X = torch.triangular_solve(B, A).solution應該替換為X = torch.linalg.solve_triangular(A, B)
- 引數
b (Tensor) – 多個右側項,形狀為 ,其中 是零個或多個批處理維度
A (Tensor) – 輸入的三角係數矩陣,形狀為 ,其中 是零個或多個批處理維度
upper (bool, optional) – 表示 是上三角矩陣還是下三角矩陣。預設值:
True。transpose (bool, optional) – 求解 op(A)X = b,其中如果此標誌為
True,則 op(A) = A^T;如果為False,則 op(A) = A。預設值:False。unitriangular (bool, optional) – 表示 是否為單位三角矩陣。如果為 True,則假設 的對角線元素為 1,並且不參考 中的值。預設值:
False。
- 關鍵字引數
out ((Tensor, Tensor), optional) – 用於寫入輸出的兩個 Tensor 的元組。如果為 None 則忽略。預設值:None。
- 返回值
一個 namedtuple (solution, cloned_coefficient),其中 cloned_coefficient 是 的副本,而 solution 是方程組 的解 (或根據關鍵字引數決定的方程組的變體)。
示例
>>> A = torch.randn(2, 2).triu() >>> A tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) >>> b = torch.randn(2, 3) >>> b tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) >>> torch.triangular_solve(b, A) torch.return_types.triangular_solve( solution=tensor([[ 1.7841, 2.9046, -2.5405], [ 1.9320, 0.9270, -1.2826]]), cloned_coefficient=tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]))