快捷方式

torch.where

torch.where(condition, input, other, *, out=None) Tensor

返回一個 Tensor,其元素根據 conditioninputother 中選擇。

該操作定義為

outi={inputiif conditioniotheriotherwise\text{out}_i = \begin{cases} \text{input}_i & \text{if } \text{condition}_i \\ \text{other}_i & \text{otherwise} \\ \end{cases}

注意

Tensor conditioninputother 必須是可廣播的

引數
  • condition (BoolTensor) – 當為 True(非零)時,取 input 中的值,否則取 other 中的值

  • input (TensorScalar) – 值為(如果 input 是標量)或在 conditionTrue 的索引處選擇的值

  • other (TensorScalar) – 值為(如果 other 是標量)或在 conditionFalse 的索引處選擇的值

關鍵字引數

out (Tensor, 可選) – 輸出 Tensor。

返回

一個 Tensor,其形狀等於 conditioninputother 廣播後的形狀

返回型別

Tensor

示例

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, 1.0, 0.0)
tensor([[0., 1.],
        [1., 0.],
        [1., 0.]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
torch.where(condition) tuple of LongTensor

torch.where(condition) 等同於 torch.nonzero(condition, as_tuple=True)

注意

另請參閱 torch.nonzero()

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源