DreamerModelLoss¶
- class torchrl.objectives.DreamerModelLoss(*args, **kwargs)[原始碼]¶
Dreamer 模型損失。
計算 dreamer 世界模型的損失。該損失由 RSSM 先驗分佈與後驗分佈之間的 KL 散度、重建觀測的重建損失以及預測獎勵的獎勵損失組成。
參考文獻:https://arxiv.org/abs/1912.01603。
- 引數:
world_model (TensorDictModule) – 世界模型。
lambda_kl (
float, optional) – KL 散度損失的權重。預設值:1.0。lambda_reco (
float, optional) – 重建損失的權重。預設值:1.0。lambda_reward (
float, optional) – 獎勵損失的權重。預設值:1.0。reco_loss (str, optional) – 重建損失型別。預設值:“l2”。
reward_loss (str, optional) – 獎勵損失型別。預設值:“l2”。
free_nats (int, optional) – 自由納特。預設值:3。
delayed_clamp (bool, optional) – 如果為
True,則 KL 截斷在平均後進行。如果為 False(預設值),則 KL 散度首先截斷到 free nats 值,然後進行平均。global_average (bool, optional) – 如果為
True,則損失將在所有維度上進行平均。否則,將對所有非批次/時間維度進行求和,並在批次和時間維度上進行平均。預設值:False。
- default_keys¶
_AcceptedKeys的別名