快捷方式

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) Tensor

沿著由 dim 指定的軸收集值。

對於一個 3-D 的 tensor,輸出由以下公式指定

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

inputindex 必須具有相同的維度數。此外,要求對於所有維度 d != dim,都有 index.size(d) <= input.size(d)out 將具有與 index 相同的形狀。注意 inputindex 不會對彼此進行廣播。

引數
  • input (Tensor) – 源 tensor

  • dim (int) – 用於索引的軸

  • index (LongTensor) – 要收集的元素的索引

關鍵字引數
  • sparse_grad (bool, 可選) – 如果為 True,則關於 input 的梯度將是稀疏 tensor。

  • out (Tensor, 可選) – 目標 tensor

示例

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的疑問

檢視資源