快捷方式

torch.index_select

torch.index_select(input, dim, index, *, out=None) Tensor

返回一個新的張量,該張量使用作為 LongTensorindex 中的條目沿 dim 維度索引 input 張量。

返回的張量與原始張量(input)具有相同的維度數。dim 維度的尺寸與 index 的長度相同;其他維度的尺寸與原始張量相同。

注意

返回的張量與原始張量使用的儲存空間**不同**。out 如果形狀與預期不同,我們將默默地將其更改為正確的形狀,必要時重新分配底層儲存空間。

引數
  • input (Tensor) – 輸入張量。

  • dim (int) – 進行索引操作的維度

  • index (IntTensorLongTensor) – 包含要索引的下標的一維張量

關鍵字引數

out (Tensor, 可選) – 輸出張量。

示例

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源