torch.baddbmm¶
- torch.baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) Tensor¶
執行
batch1和batch2中矩陣的批次矩陣乘法。結果中新增input。batch1和batch2必須是 3-D 張量,且每個張量包含相同數量的矩陣。如果
batch1是一個 張量,batch2是一個 張量,則input必須能夠 廣播 到一個 張量,並且out將是一個 張量。alpha和beta的含義與torch.addbmm()中使用的縮放因子相同。如果
beta為 0,則忽略input的內容,其中的nan和inf將不會傳播。對於
FloatTensor或DoubleTensor型別的輸入,引數beta和alpha必須是實數,否則應為整數。此運算元支援 TensorFloat32。
在某些 ROCm 裝置上,使用 float16 輸入時,此模組在反向傳播時會使用 不同的精度。
- 引數
- 關鍵字引數
beta (Number, optional) –
input的乘數 ()alpha (Number, optional) – 的乘數 ()
out (Tensor, optional) – 輸出張量。
示例
>>> M = torch.randn(10, 3, 5) >>> batch1 = torch.randn(10, 3, 4) >>> batch2 = torch.randn(10, 4, 5) >>> torch.baddbmm(M, batch1, batch2).size() torch.Size([10, 3, 5])