快捷方式

torch.split

torch.split(tensor, split_size_or_sections, dim=0)[source][source]

將張量分割成若干塊。每個塊是原始張量的一個檢視。

如果 split_size_or_sections 是整數型別,則 tensor 將被分割成大小相等的塊(如果可能)。如果張量沿給定維度 dim 的大小不能被 split_size 整除,則最後一個塊會較小。

如果 split_size_or_sections 是列表,則 tensor 將被分割成 len(split_size_or_sections) 個塊,這些塊在 dim 維的大小根據 split_size_or_sections 定義。

引數
  • tensor (Tensor) – 要分割的張量。

  • split_size_or_sections (int) 或 (list(int)) – 單個塊的大小或每個塊的大小列表

  • dim (int) – 沿哪個維度分割張量。

返回型別

tuple[torch.Tensor, …]

示例

>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的問題

檢視資源