MaskedCategorical¶
- class torchrl.modules.MaskedCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None)[source]¶
MaskedCategorical 分佈。
參考: https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical
- 引數:
logits (torch.Tensor) – 事件對數機率(未歸一化)
probs (torch.Tensor) – 事件機率。如果提供此引數,則對應於被遮蔽項的機率將歸零,並在其最後一個維度上重新歸一化機率。
- 關鍵字引數:
mask (torch.Tensor) – 一個與
logits/probs同形的布林掩碼,其中False條目是被遮蔽的。或者,如果sparse_mask為 True,則它表示分佈中有效索引的列表。與indices互斥。indices (torch.Tensor) – 一個表示哪些動作必須考慮在內的密集索引張量。與
mask互斥。neg_inf (
float, optional) – 分配給無效(超出掩碼範圍)索引的對數機率值。預設為 -inf。padding_value – 掩碼張量中的填充值。當 sparse_mask == True 時,將忽略 padding_value。
torch.manual_seed (>>>) –
torch.randn (>>> logits =) –
torch.tensor (>>> mask =) –
MaskedCategorical (>>> dist =) –
dist.sample (>>> sample =) –
print (>>>) –
tensor ([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) –
print –
-1.0831, (tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203,) – -1.1203, -1.1203])
print –
tensor –
probabilities (>>> # 使用機率) –
torch.ones (>>> prob =) –
prob.sum() (>>> prob = prob /) –
torch.tensor –
MaskedCategorical –
print –
-2.1972, (tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,) – -2.1972, -2.1972])