torch.tensor_split¶
- torch.tensor_split(input, indices_or_sections, dim=0) List of Tensors¶
根據
indices_or_sections指定的索引或分段數,沿維度dim將張量拆分為多個子張量,所有子張量都是input的檢視。此函式基於 NumPy 的numpy.array_split()。- 引數
input (Tensor) – 要拆分的張量
indices_or_sections (Tensor, int 或 list 或 tuple of ints) –
如果
indices_or_sections是一個整數n或一個值為n的零維 long tensor,則input沿維度dim被拆分為n個分段。如果input沿維度dim可以被n整除,則每個分段大小相等,為input.size(dim) / n。如果input不能被n整除,則前int(input.size(dim) % n)個分段的大小為int(input.size(dim) / n) + 1,其餘分段的大小為int(input.size(dim) / n)。如果
indices_or_sections是一個 int 的列表或元組,或者一個一維的 long tensor,則input沿維度dim在列表、元組或 tensor 中的每個索引處被拆分。例如,如果indices_or_sections=[2, 3]且dim=0,將生成張量input[:2]、input[2:3]和input[3:]。如果
indices_or_sections是一個 tensor,它必須是 CPU 上的零維或一維 long tensor。dim (int, 可選) – 沿哪個維度拆分張量。預設值:
0
示例
>>> x = torch.arange(8) >>> torch.tensor_split(x, 3) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) >>> x = torch.arange(7) >>> torch.tensor_split(x, 3) (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) >>> torch.tensor_split(x, (1, 6)) (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) >>> x = torch.arange(14).reshape(2, 7) >>> x tensor([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]]) >>> torch.tensor_split(x, 3, dim=1) (tensor([[0, 1, 2], [7, 8, 9]]), tensor([[ 3, 4], [10, 11]]), tensor([[ 5, 6], [12, 13]])) >>> torch.tensor_split(x, (1, 6), dim=1) (tensor([[0], [7]]), tensor([[ 1, 2, 3, 4, 5], [ 8, 9, 10, 11, 12]]), tensor([[ 6], [13]]))