快捷方式

torch.dsplit

torch.dsplit(input, indices_or_sections) 張量列表

根據 indices_or_sections 將具有三個或更多維度的張量 input 沿深度方向分割成多個張量。每個分割都是 input 的一個檢視。

這等同於呼叫 torch.tensor_split(input, indices_or_sections, dim=2)(分割維度是 2),不同之處在於如果 indices_or_sections 是一個整數,則它必須能整除分割維度,否則會丟擲執行時錯誤。

此函式基於 NumPy 的 numpy.dsplit()

引數
示例:
>>> t = torch.arange(16.0).reshape(2, 2, 4)
>>> t
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],
        [[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]])
>>> torch.dsplit(t, 2)
(tensor([[[ 0.,  1.],
        [ 4.,  5.]],
       [[ 8.,  9.],
        [12., 13.]]]),
 tensor([[[ 2.,  3.],
          [ 6.,  7.]],
         [[10., 11.],
          [14., 15.]]]))
>>> torch.dsplit(t, [3, 6])
(tensor([[[ 0.,  1.,  2.],
          [ 4.,  5.,  6.]],
         [[ 8.,  9., 10.],
          [12., 13., 14.]]]),
 tensor([[[ 3.],
          [ 7.]],
         [[11.],
          [15.]]]),
 tensor([], size=(2, 2, 0)))

文件

查閱全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源