快捷方式

torch.argwhere

torch.argwhere(input) Tensor

返回一個張量,其中包含 input 中所有非零元素的索引。結果中的每一行包含 input 中一個非零元素的索引。結果按字典順序排序,最後一個索引變化最快(C 風格)。

如果 inputnn 維,則結果索引張量 out 的大小為 (z×n)(z \times n),其中 zzinput 張量中非零元素的總數。

注意

此函式類似於 NumPy 的 argwhere

input 在 CUDA 上時,此函式會導致主機-裝置同步。

引數

{input}

示例

>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源