快捷方式

torch.bernoulli

torch.bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) Tensor

從伯努利分佈中抽取二元隨機數(0 或 1)。

input 張量應包含用於抽取二元隨機數的機率值。因此,input 中的所有值必須在以下範圍內:0inputi10 \leq \text{input}_i \leq 1

輸出張量的第 ith\text{i}^{th} 個元素將根據 input 中給定的第 ith\text{i}^{th} 個機率值抽取一個值為 11 的數。

outiBernoulli(p=inputi)\text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i})

返回的 out 張量只包含 0 或 1 的值,並且與 input 的形狀相同。

out 可以是整型 dtype,但 input 必須是浮點型 dtype

引數

input (Tensor) – 用於伯努利分佈的機率值輸入張量

關鍵字引數
  • generator (torch.Generator, optional) – 用於抽樣的偽隨機數生成器

  • out (Tensor, optional) – 輸出張量。

示例

>>> a = torch.empty(3, 3).uniform_(0, 1)  # generate a uniform random matrix with range [0, 1]
>>> a
tensor([[ 0.1737,  0.0950,  0.3609],
        [ 0.7148,  0.0289,  0.2676],
        [ 0.9456,  0.8937,  0.7202]])
>>> torch.bernoulli(a)
tensor([[ 1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 1.,  1.,  1.]])

>>> a = torch.ones(3, 3) # probability of drawing "1" is 1
>>> torch.bernoulli(a)
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> a = torch.zeros(3, 3) # probability of drawing "1" is 0
>>> torch.bernoulli(a)
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源