torch.split¶
- torch.split(tensor, split_size_or_sections, dim=0)[source][source]¶
將張量分割成若干塊。每個塊是原始張量的一個檢視。
如果
split_size_or_sections是整數型別,則tensor將被分割成大小相等的塊(如果可能)。如果張量沿給定維度dim的大小不能被split_size整除,則最後一個塊會較小。如果
split_size_or_sections是列表,則tensor將被分割成len(split_size_or_sections)個塊,這些塊在dim維的大小根據split_size_or_sections定義。- 引數
- 返回型別
tuple[torch.Tensor, …]
示例
>>> a = torch.arange(10).reshape(5, 2) >>> a tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) >>> torch.split(a, 2) (tensor([[0, 1], [2, 3]]), tensor([[4, 5], [6, 7]]), tensor([[8, 9]])) >>> torch.split(a, [1, 4]) (tensor([[0, 1]]), tensor([[2, 3], [4, 5], [6, 7], [8, 9]]))