快捷方式

MaskedCategorical

class torchrl.modules.MaskedCategorical(logits: Optional[Tensor] = None, probs: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, indices: Optional[Tensor] = None, neg_inf: float = - inf, padding_value: Optional[int] = None)[source]

MaskedCategorical 分佈。

參考: https://www.tensorflow.org/agents/api_docs/python/tf_agents/distributions/masked/MaskedCategorical

引數:
  • logits (torch.Tensor) – 事件對數機率(未歸一化)

  • probs (torch.Tensor) – 事件機率。如果提供此引數,則對應於被遮蔽項的機率將歸零,並在其最後一個維度上重新歸一化機率。

關鍵字引數:
  • mask (torch.Tensor) – 一個與 logits/probs 同形的布林掩碼,其中 False 條目是被遮蔽的。或者,如果 sparse_mask 為 True,則它表示分佈中有效索引的列表。與 indices 互斥。

  • indices (torch.Tensor) – 一個表示哪些動作必須考慮在內的密集索引張量。與 mask 互斥。

  • neg_inf (float, optional) – 分配給無效(超出掩碼範圍)索引的對數機率值。預設為 -inf。

  • padding_value – 掩碼張量中的填充值。當 sparse_mask == True 時,將忽略 padding_value。

  • torch.manual_seed (>>>) –

  • torch.randn (>>> logits =) –

  • torch.tensor (>>> mask =) –

  • MaskedCategorical (>>> dist =) –

  • dist.sample (>>> sample =) –

  • print (>>>) –

  • tensor ([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]) –

  • print

  • -1.0831, (tensor([-1.1203, -1.0928, -1.0831, -1.1203, -1.1203, -1.0831, -1.1203,) – -1.1203, -1.1203])

  • print

  • tensor

  • probabilities (>>> # 使用機率) –

  • torch.ones (>>> prob =) –

  • prob.sum() (>>> prob = prob /) –

  • torch.tensor

  • MaskedCategorical

  • print

  • -2.1972, (tensor([ -inf, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972, -2.1972,) – -2.1972, -2.1972])

log_prob(value: Tensor) Tensor[source]

返回在 value 處評估的機率密度/質量函式的對數。

引數:

value (張量) –

sample(sample_shape: Optional[Union[Size, Sequence[int]]] = None) Tensor[source]

如果分佈引數是批次化的,則生成 sample_shape 形狀的樣本或 sample_shape 形狀的批次樣本。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源