快捷方式

torch.lu

torch.lu(*args, **kwargs)[原始碼]

計算矩陣或批次矩陣 A 的 LU 分解。返回一個包含 A 的 LU 分解和主元的元組。如果 pivot 設定為 True,則執行主元選擇。

警告

torch.lu() 已被 torch.linalg.lu_factor()torch.linalg.lu_factor_ex() 棄用。torch.lu() 將在未來的 PyTorch 版本中移除。LU, pivots, info = torch.lu(A, compute_pivots) 應替換為

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) 應替換為

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 批次中每個矩陣返回的置換矩陣由一個 1-indexed (起始索引為 1) 的向量表示,其大小為 min(A.shape[-2], A.shape[-1])pivots[i] == j 表示在演算法的第 i 步,第 i 行與第 j-1 行進行了置換。

  • 對於 CPU,不支援 pivot = False 的 LU 分解,嘗試這樣做將丟擲錯誤。但是,對於 CUDA,支援 pivot = False 的 LU 分解。

  • 如果 get_infosTrue,此函式不會檢查分解是否成功,因為分解的狀態存在於返回元組的第三個元素中。

  • 對於 CUDA 裝置上大小小於或等於 32 的批次方陣,由於 MAGMA 庫中的錯誤 (參見 magma issue 13),會為奇異矩陣重複執行 LU 分解。

  • LUP 可以使用 torch.lu_unpack() 派生得到。

警告

此函式的梯度僅在 A 是滿秩時為有限值。這是因為 LU 分解僅在滿秩矩陣處可微。此外,如果 A 接近非滿秩,則梯度將由於依賴於 L1L^{-1}U1U^{-1} 的計算而導致數值不穩定。

引數
  • A (Tensor) – 要分解的 tensor,大小為 (,m,n)(*, m, n)

  • pivot (bool, 可選) – 控制是否進行主元選擇。預設值:True

  • get_infos (bool, 可選) – 如果設定為 True,則返回一個 info IntTensor。預設值:False

  • out (tuple, 可選) – 可選的輸出元組。如果 get_infosTrue,則元組中的元素為 Tensor、IntTensor 和 IntTensor。如果 get_infosFalse,則元組中的元素為 Tensor、IntTensor。預設值:None

返回

一個 tensor 元組,包含

  • factorization (Tensor):分解結果,大小為 (,m,n)(*, m, n)

  • pivots (IntTensor):主元,大小為 (,min(m,n))(*, \text{min}(m, n))pivots 儲存了所有中間的行轉置。可以透過對 i = 0, ..., pivots.size(-1) - 1 應用 swap(perm[i], perm[pivots[i] - 1]) 來重構最終的置換 perm,其中 perm 最初是 mm 個元素的單位置換(這與 torch.lu_unpack() 所做的事情基本相同)。

  • infos (IntTensor, 可選):如果 get_infosTrue,這是一個大小為 ()(*) 的 tensor,其中非零值表示矩陣或每個 mini-batch 的分解是成功還是失敗。

返回型別

(Tensor, IntTensor, IntTensor (可選))

示例

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!

© 版權所有 PyTorch 貢獻者。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源