safe_int_mm¶
- torchao.quantization.safe_int_mm(input: Tensor, mat2: Tensor) Tensor[原始碼]¶
執行安全的整數矩陣乘法,考慮了 torch.compile、cublas 和回退情況下的不同路徑。
- 引數:
input (torch.Tensor) – 輸入張量,形狀為 [i, j]。
mat2 (torch.Tensor) – 用於乘法的矩陣,形狀為 [j, k]。
- 返回值:
矩陣乘法的結果。
- 返回型別:
- 丟擲:
AssertionError – 如果張量不在同一裝置上。