torch.index_select¶
- torch.index_select(input, dim, index, *, out=None) Tensor¶
返回一個新的張量,該張量使用作為 LongTensor 的
index中的條目沿dim維度索引input張量。返回的張量與原始張量(
input)具有相同的維度數。dim維度的尺寸與index的長度相同;其他維度的尺寸與原始張量相同。注意
返回的張量與原始張量使用的儲存空間**不同**。
out如果形狀與預期不同,我們將默默地將其更改為正確的形狀,必要時重新分配底層儲存空間。- 引數
- 關鍵字引數
out (Tensor, 可選) – 輸出張量。
示例
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) >>> torch.index_select(x, 0, indices) tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> torch.index_select(x, 1, indices) tensor([[ 0.1427, -0.5414], [-0.4664, -0.1228], [-1.1734, 0.7230]])