torch.nn.functional.one_hot¶
- torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor¶
接收形狀為
(*)的 LongTensor,其中包含索引值,並返回一個形狀為(*, num_classes)的 tensor。此 tensor 除最後一維的索引與輸入 tensor 對應值匹配的位置為 1 外,其餘位置均為 0。另請參閱 維基百科上的 One-hot(獨熱碼)。
- 引數
tensor (LongTensor) – 任何形狀的類別值。
num_classes (int, optional) – 類別的總數量。如果設定為 -1,則類別數量將推斷為輸入 tensor 中最大類別值加一。預設值:-1
- 返回
一個 LongTensor,它在最後一維的輸入指定索引位置的值為 1,其餘位置為 0,且該維度比輸入多一維。
示例
>>> F.one_hot(torch.arange(0, 5) % 3) tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]) >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]) >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) tensor([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]])