快捷方式

torch.linalg.multi_dot

torch.linalg.multi_dot(tensors, *, out=None)

透過重新排序乘法操作,有效率地乘以兩個或多個矩陣,以執行最少次數的算術運算。

支援 float, double, cfloat 和 cdouble 資料型別的輸入。此函式不支援批處理輸入。

tensors 中的每個 tensor 都必須是 2D 的,除了第一個和最後一個可以是 1D 的。如果第一個 tensor 是形狀為 (n,) 的 1D 向量,它將被視為形狀為 (1, n) 的行向量;類似地,如果最後一個 tensor 是形狀為 (n,) 的 1D 向量,它將被視為形狀為 (n, 1) 的列向量。

如果第一個和最後一個 tensor 都是矩陣,則輸出將是矩陣。然而,如果其中任何一個是一維向量,則輸出將是一維向量。

numpy.linalg.multi_dot 的區別

  • numpy.linalg.multi_dot 不同,第一個和最後一個 tensor 必須是 1D 或 2D,而 NumPy 允許它們是 nD。

警告

此函式不支援廣播。

注意

此函式透過計算最優矩陣乘法順序後,鏈式呼叫 torch.mm() 來實現。

注意

兩個形狀分別為 (a, b)(b, c) 的矩陣相乘的代價是 a * b * c。給定形狀分別為 (10, 100)(100, 5)(5, 50) 的矩陣 ABC,我們可以按如下方式計算不同乘法順序的代價:

cost((AB)C)=10×100×5+10×5×50=7500cost(A(BC))=10×100×50+100×5×50=75000\begin{align*} \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 \end{align*}

在這種情況下,先將 A 和 B 相乘,再乘以 C,速度快 10 倍。

引數

tensors (Sequence[Tensor]) – 要相乘的兩個或多個 tensor。第一個和最後一個 tensor 可以是 1D 或 2D。所有其他 tensor 必須是 2D 的。

關鍵字引數

out (Tensor, optional) – 輸出 tensor。如果為 None 則忽略。預設值: None

示例

>>> from torch.linalg import multi_dot

>>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])])
tensor(8)
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])])
tensor([8])
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])])
tensor([[8]])

>>> A = torch.arange(2 * 3).view(2, 3)
>>> B = torch.arange(3 * 2).view(3, 2)
>>> C = torch.arange(2 * 2).view(2, 2)
>>> multi_dot((A, B, C))
tensor([[ 26,  49],
        [ 80, 148]])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源