torch.sparse.mm¶
- torch.sparse.mm()¶
對稀疏矩陣
mat1和(稀疏或跨步)矩陣mat2執行矩陣乘法。類似於torch.mm(),如果mat1是一個 張量,mat2是一個 張量,則輸出將是一個 張量。當mat1是 COO 張量時,它必須具有 sparse_dim = 2。當輸入為 COO 張量時,此函式也支援對兩個輸入進行反向傳播。支援 CSR 和 COO 儲存格式。
注意
此函式不支援計算相對於 CSR 矩陣的導數。
此函式還額外接受一個可選的
reduce引數,允許指定可選的歸約(reduction)操作,數學上執行以下運算:其中 定義了歸約運算子。
reduce僅在 CPU 裝置上支援 CSR 儲存格式。- 引數
- 形狀
此函式的輸出張量格式如下: - 稀疏 x 稀疏 -> 稀疏 - 稀疏 x 稠密 -> 稠密
示例
>>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_() >>> a tensor(indices=tensor([[0, 0, 1], [0, 2, 1]]), values=tensor([1., 2., 3.]), size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True) >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True) >>> b tensor([[0., 1.], [2., 0.], [0., 0.]], requires_grad=True) >>> y = torch.sparse.mm(a, b) >>> y tensor([[0., 1.], [6., 0.]], grad_fn=<SparseAddmmBackward0>) >>> y.sum().backward() >>> a.grad tensor(indices=tensor([[0, 0, 1], [0, 2, 1]]), values=tensor([1., 0., 2.]), size=(2, 3), nnz=3, layout=torch.sparse_coo) >>> c = a.detach().to_sparse_csr() >>> c tensor(crow_indices=tensor([0, 2, 3]), col_indices=tensor([0, 2, 1]), values=tensor([1., 2., 3.]), size=(2, 3), nnz=3, layout=torch.sparse_csr) >>> y1 = torch.sparse.mm(c, b, 'sum') >>> y1 tensor([[0., 1.], [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>) >>> y2 = torch.sparse.mm(c, b, 'max') >>> y2 tensor([[0., 1.], [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)