快捷方式

torch.take_along_dim

torch.take_along_dim(input, indices, dim=None, *, out=None) Tensor

沿著給定的 dim,根據 indices 中的一維索引從 input 中選取值。

如果 dim 為 None,則將輸入陣列視為已展平為一維。

返回沿某個維度的索引的函式,例如 torch.argmax()torch.argsort(),被設計為與此函式一起使用。請參見下面的示例。

注意

此函式類似於 NumPy 的 take_along_axis。另請參閱 torch.gather()

引數
  • input (Tensor) – 輸入張量。

  • indices (LongTensor) – input 中的索引。必須是 long 資料型別。

  • dim (int, optional) – 沿哪個維度選擇。預設值:0

關鍵字引數

out (Tensor, optional) – 輸出張量。

示例

>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]])
>>> max_idx = torch.argmax(t)
>>> torch.take_along_dim(t, max_idx)
tensor([60])
>>> sorted_idx = torch.argsort(t, dim=1)
>>> torch.take_along_dim(t, sorted_idx, dim=1)
tensor([[10, 20, 30],
        [40, 50, 60]])

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源