快捷方式

torch.nn.functional.one_hot

torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor

接收形狀為 (*) 的 LongTensor,其中包含索引值,並返回一個形狀為 (*, num_classes) 的 tensor。此 tensor 除最後一維的索引與輸入 tensor 對應值匹配的位置為 1 外,其餘位置均為 0。

另請參閱 維基百科上的 One-hot(獨熱碼)

引數
  • tensor (LongTensor) – 任何形狀的類別值。

  • num_classes (int, optional) – 類別的總數量。如果設定為 -1,則類別數量將推斷為輸入 tensor 中最大類別值加一。預設值:-1

返回

一個 LongTensor,它在最後一維的輸入指定索引位置的值為 1,其餘位置為 0,且該維度比輸入多一維。

示例

>>> F.one_hot(torch.arange(0, 5) % 3)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])
>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])
>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
tensor([[[1, 0, 0],
         [0, 1, 0]],
        [[0, 0, 1],
         [1, 0, 0]],
        [[0, 1, 0],
         [0, 0, 1]]])

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源