tensordict.nn.distributions.CompositeDistribution¶
- class tensordict.nn.distributions.CompositeDistribution(params: TensorDictBase, distribution_map: dict, *, name_map: Optional[dict] = None, extra_kwargs=None, aggregate_probabilities: Optional[bool] = None, log_prob_key: Optional[NestedKey] = None, entropy_key: Optional[NestedKey] = None)¶
一個使用 TensorDict 介面將多個分佈組合在一起的複合分佈。
此類允許對分佈集合執行諸如 log_prob_composite、entropy_composite、cdf、icdf、rsample 和 sample 等操作,並返回一個 TensorDict。輸入的 TensorDict 可能會被就地修改。
- 引數:
params (TensorDictBase) – 一個巢狀的鍵-張量對映,其中根條目對應於樣本名稱,葉子條目是分佈引數。條目名稱必須與 distribution_map 中指定的名稱匹配。
distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]) – 指定要使用的分佈型別。分佈的名稱應與 TensorDict 中的樣本名稱匹配。
- 關鍵字引數:
name_map (Dict[NestedKey, NestedKey], optional) – 一個對映,指定每個樣本應寫入的位置。如果未提供,將使用 distribution_map 中的鍵名稱。
extra_kwargs (Dict[NestedKey, Dict], optional) – 用於構造分佈的額外關鍵字引數字典。
aggregate_probabilities (bool, optional) –
如果為 True,log_prob 和 entropy 方法將對各個分佈的機率和熵求和並返回單個張量。如果為 False,單個對數機率將儲存在輸入的 TensorDict 中(對於 log_prob),或作為輸出 TensorDict 的葉子返回(對於 entropy)。這可以在執行時透過將 aggregate_probabilities 引數傳遞給 log_prob 和 entropy 來覆蓋。預設為 False。
警告
此引數將在 v0.9 中棄用,屆時
tensordict.nn.probabilistic.composite_lp_aggregate()將預設為False。log_prob_key (NestedKey, optional) –
儲存聚合對數機率的鍵。預設為 ‘sample_log_prob’。
注意
如果
tensordict.nn.probabilistic.composite_lp_aggregate()返回False,則對數機率將寫入 (“path”, “to”, “leaf”, “<sample_name>_log_prob”) 下,其中 (“path”, “to”, “leaf”, “<sample_name>”) 是對應於正在取樣的葉子張量的NestedKey。在這種情況下,log_prob_key引數將被忽略。entropy_key (NestedKey, optional) –
儲存熵的鍵。預設為 ‘entropy’
注意
如果
tensordict.nn.probabilistic.composite_lp_aggregate()返回False,則熵將寫入 (“path”, “to”, “leaf”, “<sample_name>_entropy”) 下,其中 (“path”, “to”, “leaf”, “<sample_name>”) 是對應於正在取樣的葉子張量的NestedKey。在這種情況下,entropy_key引數將被忽略。
注意
包含引數(params)的輸入 TensorDict 的批次大小決定了分佈的批次形狀。例如,呼叫 log_prob 產生的 “sample_log_prob” 條目將具有引數的形狀加上任何額外的批次維度。
另請參閱
ProbabilisticTensorDictModule和ProbabilisticTensorDictSequential,瞭解如何將此類用作模型的一部分。另請參閱
set_composite_lp_aggregate,控制對數機率的聚合。示例
>>> params = TensorDict({ ... "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)}, ... ("nested", "disc"): {"logits": torch.randn(3, 10)} ... }, [3]) >>> dist = CompositeDistribution(params, ... distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical}) >>> sample = dist.sample((4,)) >>> with set_composite_lp_aggregate(False): ... sample = dist.log_prob(sample) ... print(sample) TensorDict( fields={ cont: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), cont_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ disc: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.int64, is_shared=False), disc_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)