快捷方式

torch.flatten

torch.flatten(input, start_dim=0, end_dim=-1) Tensor

透過將 input 重塑為一維張量來展平它。如果傳入 start_dimend_dim,則只展平從 start_dim 開始到 end_dim 結束的維度。input 中元素的順序保持不變。

與 NumPy 的 flatten 不同,後者總是複製輸入資料,此函式可能會返回原始物件、檢視或副本。如果未展平任何維度,則返回原始物件 input。否則,如果 input 可以被視為展平後的形狀,則返回該檢視。最後,只有當輸入不能被視為展平後的形狀時,才會複製輸入的資料。關於何時返回檢視的詳細資訊,請參閱 torch.Tensor.view()

注意

展平零維張量將返回一個一維檢視。

引數
  • input (Tensor) – 輸入張量。

  • start_dim (int) – 要展平的起始維度

  • end_dim (int) – 要展平的結束維度

示例

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源