MaskedOneHotCategorical¶
- class torchrl.modules.MaskedOneHotCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None, grad_method: ReparamGradientStrategy = ReparamGradientStrategy.PassThrough)[source]¶
MaskedCategorical 分佈。
參考:https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical
- 引數:
logits (torch.Tensor) – 事件對數機率(未歸一化)
probs (torch.Tensor) – 事件機率。如果提供,將被掩碼項對應的機率置零,並在其最後一個維度上重新歸一化。
- 關鍵字引數:
mask (torch.Tensor) – 一個布林掩碼,與
logits/probs具有相同的形狀,其中False條目表示被掩碼的項。或者,如果sparse_mask為 True,則它表示分佈中有效索引的列表。與indices互斥。indices (torch.Tensor) – 一個表示需要考慮哪些動作的密集索引張量。與
mask互斥。neg_inf (
float, optional) – 分配給無效(在掩碼之外)索引的對數機率值。預設為 -inf。padding_value – 當 `sparse_mask == True` 時,掩碼張量中的填充值,該 `padding_value` 將被忽略。
grad_method (ReparamGradientStrategy, optional) –
用於收集重引數化樣本的策略。
ReparamGradientStrategy.PassThrough將計算樣本梯度透過使用 softmax 值的對數機率作為樣本梯度的近似。
ReparamGradientStrategy.RelaxedOneHot將使用torch.distributions.RelaxedOneHot從分佈中取樣。
示例
>>> torch.manual_seed(0) >>> logits = torch.randn(4) / 100 # almost equal probabilities >>> mask = torch.tensor([True, False, True, True]) >>> dist = MaskedOneHotCategorical(logits=logits, mask=mask) >>> sample = dist.sample((10,)) >>> print(sample) # no `1` in the sample tensor([[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]]) >>> print(dist.log_prob(sample)) tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203, -1.0831, -1.1203, -1.1203]) >>> sample_non_valid = torch.zeros_like(sample) >>> sample_non_valid[..., 1] = 1 >>> print(dist.log_prob(sample_non_valid)) tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) >>> # with probabilities >>> prob = torch.ones(10) >>> prob = prob / prob.sum() >>> mask = torch.tensor([False] + 9 * [True]) # first outcome is masked >>> dist = MaskedOneHotCategorical(probs=prob, mask=mask) >>> s = torch.arange(10) >>> s = torch.nn.functional.one_hot(s, 10) >>> print(dist.log_prob(s)) tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972])