torch.bmm¶
- torch.bmm(input, mat2, *, out=None) Tensor¶
執行儲存在
input和mat2中的矩陣批次矩陣乘法。input和mat2必須是 3 維張量,且包含相同數量的矩陣。如果
input是一個 張量,mat2是一個 張量,則out將是一個 張量。此運算元支援 TensorFloat32。
在某些 ROCm 裝置上,當使用 float16 輸入時,此模組在反向傳播時將使用不同的精度。
注意
此函式不支援廣播。對於廣播矩陣乘法,請參閱
torch.matmul()。示例
>>> input = torch.randn(10, 3, 4) >>> mat2 = torch.randn(10, 4, 5) >>> res = torch.bmm(input, mat2) >>> res.size() torch.Size([10, 3, 5])