torch.argwhere¶
- torch.argwhere(input) Tensor¶
返回一個張量,其中包含
input中所有非零元素的索引。結果中的每一行包含input中一個非零元素的索引。結果按字典順序排序,最後一個索引變化最快(C 風格)。如果
input有 維,則結果索引張量out的大小為 ,其中 是input張量中非零元素的總數。注意
此函式類似於 NumPy 的 argwhere。
當
input在 CUDA 上時,此函式會導致主機-裝置同步。- 引數
{input} –
示例
>>> t = torch.tensor([1, 0, 1]) >>> torch.argwhere(t) tensor([[0], [2]]) >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) >>> torch.argwhere(t) tensor([[0, 0], [0, 2], [1, 1], [1, 2]])