torch.flatten¶
- torch.flatten(input, start_dim=0, end_dim=-1) Tensor¶
透過將
input重塑為一維張量來展平它。如果傳入start_dim或end_dim,則只展平從start_dim開始到end_dim結束的維度。input中元素的順序保持不變。與 NumPy 的 flatten 不同,後者總是複製輸入資料,此函式可能會返回原始物件、檢視或副本。如果未展平任何維度,則返回原始物件
input。否則,如果 input 可以被視為展平後的形狀,則返回該檢視。最後,只有當輸入不能被視為展平後的形狀時,才會複製輸入的資料。關於何時返回檢視的詳細資訊,請參閱torch.Tensor.view()。注意
展平零維張量將返回一個一維檢視。
示例
>>> t = torch.tensor([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> torch.flatten(t) tensor([1, 2, 3, 4, 5, 6, 7, 8]) >>> torch.flatten(t, start_dim=1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]])