快捷方式

MaskedOneHotCategorical

class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)[source]

MaskedCategorical 分佈。

參考:https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

引數:
  • logits (torch.Tensor) – 事件對數機率(未歸一化)

  • probs (torch.Tensor) – 事件機率。如果提供,將被掩碼項對應的機率置零,並在其最後一個維度上重新歸一化。

關鍵字引數:
  • mask (torch.Tensor) – 一個布林掩碼,與 logits/probs 具有相同的形狀,其中 False 條目表示被掩碼的項。或者,如果 sparse_mask 為 True,則它表示分佈中有效索引的列表。與 indices 互斥。

  • indices (torch.Tensor) – 一個表示需要考慮哪些動作的密集索引張量。與 mask 互斥。

  • neg_inf (float, optional) – 分配給無效(在掩碼之外)索引的對數機率值。預設為 -inf。

  • padding_value – 當 `sparse_mask == True` 時,掩碼張量中的填充值,該 `padding_value` 將被忽略。

  • grad_method (ReparamGradientStrategy, optional) –

    用於收集重引數化樣本的策略。 ReparamGradientStrategy.PassThrough 將計算樣本梯度

    透過使用 softmax 值的對數機率作為樣本梯度的近似。

    ReparamGradientStrategy.RelaxedOneHot 將使用 torch.distributions.RelaxedOneHot 從分佈中取樣。

示例

>>> torch.manual_seed(0)
>>> logits = torch.randn(4) / 100  # almost equal probabilities
>>> mask = torch.tensor([True, False, True, True])
>>> dist = MaskedOneHotCategorical(logits=logits, mask=mask)
>>> sample = dist.sample((10,))
>>> print(sample)  # no `1` in the sample
tensor([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]])
>>> print(dist.log_prob(sample))
tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831,
        -1.1203, -1.1203])
>>> sample_non_valid = torch.zeros_like(sample)
>>> sample_non_valid[..., 1] = 1
>>> print(dist.log_prob(sample_non_valid))
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
>>> # with probabilities
>>> prob = torch.ones(10)
>>> prob = prob / prob.sum()
>>> mask = torch.tensor([False] + 9 * [True])  # first outcome is masked
>>> dist = MaskedOneHotCategorical(probs=prob, mask=mask)
>>> s = torch.arange(10)
>>> s = torch.nn.functional.one_hot(s, 10)
>>> print(dist.log_prob(s))
tensor([   -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,
        -2.1972, -2.1972])
log_prob(value: Tensor) Tensor[source]

返回在 value 處評估的機率密度/質量函式的對數。

引數:

value (Tensor) –

property mode: Tensor

返回分佈的眾數。

rsample(sample_shape: Optional[Union[Size, Sequence]] = None) Tensor[source]

如果分佈引數是批次的,則生成 `sample_shape` 形狀的重引數化樣本或 `sample_shape` 形狀的重引數化樣本批次。

sample(sample_shape: Optional[Union[Size, Sequence[int]]] = None) Tensor[source]

如果分佈引數是批次的,則生成 `sample_shape` 形狀的樣本或 `sample_shape` 形狀的樣本批次。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源