torch.linalg.tensorinv¶
- torch.linalg.tensorinv(A, ind=2, *, out=None) Tensor¶
計算
torch.tensordot()的乘法逆。如果 m 是
A的前ind個維度的乘積,n 是其餘維度的乘積,則此函式期望 m 和 n 相等。如果滿足此條件,它將計算一個 tensor X,使得 tensordot(A, X,ind) 是維度 m 上的單位矩陣。X 的形狀將與A相同,但前ind個維度將被移到末尾。X.shape == A.shape[ind:] + A.shape[:ind]
支援 float, double, cfloat 和 cdouble 資料型別的輸入。
注意
當
A是一個 2 維 tensor 且ind= 1 時,此函式計算A的(乘法)逆(參見torch.linalg.inv())。注意
如果可能,請考慮使用
torch.linalg.tensorsolve()來計算 tensor 逆與 tensor 的左乘,因為linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # When B is a tensor with shape A.shape[:B.ndim]
在可能的情況下,總是優先使用
tensorsolve(),因為它比顯式計算偽逆更快且數值更穩定。另請參閱
torch.linalg.tensorsolve()計算 torch.tensordot(tensorinv(A),B)。- 引數
A (Tensor) – 要取逆的 tensor。其形狀必須滿足 prod(
A.shape[:ind]) == prod(A.shape[ind:])。ind (int) – 計算
torch.tensordot()的逆的索引。預設值:2。
- 關鍵字引數
out (Tensor, optional) – 輸出 tensor。如果為 None 則忽略。預設值:None。
- 丟擲
RuntimeError – 如果重塑後的
A不可逆,或者前ind個維度的乘積與其餘維度的乘積不相等。
示例
>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) >>> Ainv = torch.linalg.tensorinv(A, ind=2) >>> Ainv.shape torch.Size([8, 3, 4, 6]) >>> B = torch.randn(4, 6) >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) True >>> A = torch.randn(4, 4) >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) >>> Ainv = torch.linalg.inv(A) >>> torch.allclose(Atensorinv, Ainv) True