快捷方式

next_state_value

torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: Optional[TensorDictModule] = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, **kwargs)[source]

計算下一個狀態值(無梯度)以計算目標值。

目標值通常用於計算距離損失(例如 MSE)

L = Sum[ (q_value - target_value)^2 ]

目標值計算如下

r + gamma ** n_steps_to_next * value_next_state

如果獎勵是即時獎勵,則 n_steps_to_next=1。如果使用 N 步獎勵,則從輸入的 tensordict 中收集 n_steps_to_next。

引數:
  • tensordict (TensorDictBase) – 包含 reward 和 done 鍵(以及 n-steps 獎勵的 n_steps_to_next 鍵)的 Tensordict。

  • operator (ProbabilisticTDModule, optional) – 值函式運算元。呼叫時應在輸入的 tensordict 中寫入一個 ‘next_val_key’ 鍵值對。如果提供了 pred_next_val,則無需提供此引數。

  • next_val_key (str, optional) – 將寫入下一個值的鍵。預設值: ‘state_action_value’

  • gamma (float, optional) – 回報折扣率。預設值: 0.99

  • pred_next_val (Tensor, optional) – 如果未使用運算元計算,則可以提供下一個狀態值。

返回值:

一個與輸入的 tensordict 大小相同的 Tensor,包含預測的狀態值。


© 版權所有 2022, Meta。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源