torch.where¶
- torch.where(condition, input, other, *, out=None) Tensor¶
返回一個 Tensor,其元素根據
condition從input或other中選擇。該操作定義為
注意
Tensor
condition、input、other必須是可廣播的。- 引數
- 關鍵字引數
out (Tensor, 可選) – 輸出 Tensor。
- 返回
一個 Tensor,其形狀等於
condition、input、other廣播後的形狀- 返回型別
示例
>>> x = torch.randn(3, 2) >>> y = torch.ones(3, 2) >>> x tensor([[-0.4620, 0.3139], [ 0.3898, -0.7197], [ 0.0478, -0.1657]]) >>> torch.where(x > 0, 1.0, 0.0) tensor([[0., 1.], [1., 0.], [1., 0.]]) >>> torch.where(x > 0, x, y) tensor([[ 1.0000, 0.3139], [ 0.3898, 1.0000], [ 0.0478, 1.0000]]) >>> x = torch.randn(2, 2, dtype=torch.double) >>> x tensor([[ 1.0779, 0.0383], [-0.8785, -1.1089]], dtype=torch.float64) >>> torch.where(x > 0, x, 0.) tensor([[1.0779, 0.0383], [0.0000, 0.0000]], dtype=torch.float64)
- torch.where(condition) tuple of LongTensor
torch.where(condition)等同於torch.nonzero(condition, as_tuple=True)。注意
另請參閱
torch.nonzero()。