快捷方式

torch.mm

torch.mm(input, mat2, *, out=None) 張量

對矩陣 inputmat2 執行矩陣乘法。

如果 input 是一個 (n×m)(n \times m) 張量,mat2 是一個 (m×p)(m \times p) 張量,則 out 將是一個 (n×p)(n \times p) 張量。

注意

此函式不支援廣播。對於廣播矩陣乘積,請參閱torch.matmul()

支援步長和稀疏的二維張量作為輸入,以及對步長輸入的自動微分。

此操作支援具有稀疏佈局的引數。如果提供了 out,將使用其佈局。否則,結果佈局將從 input 的佈局推導。

警告

稀疏支援是一項 Beta 特性,某些佈局/資料型別/裝置組合可能不受支援,或者可能沒有自動微分支援。如果您發現缺少功能,請提出功能請求。

此運算元支援TensorFloat32

在某些 ROCm 裝置上,使用 float16 輸入時,此模組將使用不同的精度進行反向傳播。

引數
  • input (張量) – 要進行矩陣乘法的第一個矩陣

  • mat2 (張量) – 要進行矩陣乘法的第二個矩陣

關鍵字引數

out (張量,可選) – 輸出張量。

示例

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源