快捷方式

torch.vsplit

torch.vsplit(input, indices_or_sections) 張量列表

根據 indices_or_sections 垂直地將具有兩個或更多維度的張量 input 拆分成多個張量。每個拆分都是 input 的檢視。

這等效於呼叫 torch.tensor_split(input, indices_or_sections, dim=0)(拆分維度是 0),但如果 indices_or_sections 是一個整數,則它必須能夠整除拆分維度,否則將丟擲執行時錯誤。

此函式基於 NumPy 的 numpy.vsplit()

引數
示例:
>>> t = torch.arange(16.0).reshape(4,4)
>>> t
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
>>> torch.vsplit(t, 2)
(tensor([[0., 1., 2., 3.],
         [4., 5., 6., 7.]]),
 tensor([[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]))
>>> torch.vsplit(t, [3, 6])
(tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]),
 tensor([[12., 13., 14., 15.]]),
 tensor([], size=(0, 4)))

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源