快捷方式

OneHotCategorical

torchrl.modules.OneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough, **kwargs)[source]

獨熱分類分佈。

此類行為與 torch.distributions.Categorical 完全一致,區別在於它讀取並生成離散張量的獨熱編碼。

引數:
  • logits (torch.Tensor) – 事件的對數機率(未歸一化)

  • probs (torch.Tensor) – 事件機率

  • grad_method (ReparamGradientStrategy, 可選) –

    收集重引數化樣本的策略。ReparamGradientStrategy.PassThrough 將透過使用 softmax 值的對數機率作為樣本梯度的代理來計算樣本梯度。

    ReparamGradientStrategy.RelaxedOneHot 將使用 torch.distributions.RelaxedOneHot 從分佈中進行取樣。

示例

>>> torch.manual_seed(0)
>>> logits = torch.randn(4)
>>> dist = OneHotCategorical(logits=logits)
>>> print(dist.rsample((3,)))
tensor([[1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])
entropy()[source]

返回分佈的熵,已按 batch_shape 批處理。

返回:

形狀為 batch_shape 的 Tensor。

log_prob(value: Tensor) Tensor[source]

返回在 value 處評估的機率密度/質量函式的對數。

引數:

value (Tensor) –

屬性 mode: Tensor

返回分佈的眾數。

rsample(sample_shape: Optional[Union[Size, Sequence]] = None) Tensor[source]

生成 sample_shape 形狀的重引數化樣本,或者如果分佈引數是批處理的,則生成 sample_shape 形狀的重引數化樣本批次。

sample(sample_shape: Optional[Union[Size, Sequence]] = None) Tensor[source]

生成 sample_shape 形狀的樣本,或者如果分佈引數是批處理的,則生成 sample_shape 形狀的樣本批次。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取解答

檢視資源