快捷方式

torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor

返回一個張量,其中每一行包含從對應於輸入張量 input 中各行的多項式(更嚴格的定義是多元分佈,詳情請參閱 torch.distributions.multinomial.Multinomial)機率分佈中抽取的 num_samples 個索引。

注意

input 的行不需要總和為一(在這種情況下,我們將使用值作為權重),但必須是非負數、有限數且總和非零。

索引按照取樣的順序從左到右排列(第一個樣本放在第一列)。

如果 input 是一個向量,out 是一個大小為 num_samples 的向量。

如果 input 是一個具有 m 行的矩陣,out 是一個形狀為 (m×num_samples)(m \times \text{num\_samples}) 的矩陣。

如果 replacement 為 True,則樣本進行有放回抽取。

否則,樣本進行無放回抽取,這意味著對於某一行,一旦某個樣本索引被抽取,它就不能再被抽取。

注意

進行無放回抽取時,num_samples 必須小於 input 中的非零元素數量(如果 input 是矩陣,則必須小於每行中的最小非零元素數量)。

引數
  • input (Tensor) – 包含機率的輸入張量

  • num_samples (int) – 要抽取的樣本數量

  • replacement (bool, optional) – 是否有放回抽取

關鍵字引數
  • generator (torch.Generator, optional) – 用於取樣的偽隨機數生成器

  • out (Tensor, optional) – 輸出張量。

示例

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源