快捷方式

torch.bmm

torch.bmm(input, mat2, *, out=None) Tensor

執行儲存在 inputmat2 中的矩陣批次矩陣乘法。

inputmat2 必須是 3 維張量,且包含相同數量的矩陣。

如果 input 是一個 (b×n×m)(b \times n \times m) 張量,mat2 是一個 (b×m×p)(b \times m \times p) 張量,則 out 將是一個 (b×n×p)(b \times n \times p) 張量。

outi=inputi@mat2i\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i

此運算元支援 TensorFloat32

在某些 ROCm 裝置上,當使用 float16 輸入時,此模組在反向傳播時將使用不同的精度

注意

此函式不支援廣播。對於廣播矩陣乘法,請參閱torch.matmul()

引數
  • input (Tensor) – 要相乘的第一批矩陣

  • mat2 (Tensor) – 要相乘的第二批矩陣

關鍵字引數

out (Tensor, optional) – 輸出張量。

示例

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源