torch.multinomial¶
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor¶
返回一個張量,其中每一行包含從對應於輸入張量
input中各行的多項式(更嚴格的定義是多元分佈,詳情請參閱torch.distributions.multinomial.Multinomial)機率分佈中抽取的num_samples個索引。注意
input的行不需要總和為一(在這種情況下,我們將使用值作為權重),但必須是非負數、有限數且總和非零。索引按照取樣的順序從左到右排列(第一個樣本放在第一列)。
如果
input是一個向量,out是一個大小為num_samples的向量。如果
input是一個具有 m 行的矩陣,out是一個形狀為 的矩陣。如果 replacement 為
True,則樣本進行有放回抽取。否則,樣本進行無放回抽取,這意味著對於某一行,一旦某個樣本索引被抽取,它就不能再被抽取。
注意
進行無放回抽取時,
num_samples必須小於input中的非零元素數量(如果input是矩陣,則必須小於每行中的最小非零元素數量)。- 引數
- 關鍵字引數
generator (
torch.Generator, optional) – 用於取樣的偽隨機數生成器out (Tensor, optional) – 輸出張量。
示例
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> torch.multinomial(weights, 2) tensor([1, 2]) >>> torch.multinomial(weights, 5) # ERROR! RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1])