ValueEstimatorBase¶
- class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[source]¶
值函式模組的抽象父類。
其
ValueFunctionBase.forward()方法將計算值(由值網路給出)和值估計(由值估計器給出)以及優勢,並將這些值寫入輸出 tensordict。如果只需要值估計,則應使用
ValueFunctionBase.value_estimate()。- default_keys¶
_AcceptedKeys的別名
- abstract forward(tensordict: TensorDictBase, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None) TensorDictBase[source]¶
計算給定 tensordict 中的資料的優勢估計。
如果提供了函式式模組,則可以將包含引數(以及相關的目標引數)的巢狀 TensorDict 傳遞給該模組。
- 引數:
tensordict (TensorDictBase) – 包含計算值估計和 TDEstimate 所需資料(觀察鍵、
"action"、("next", "reward")、("next", "done")、("next", "terminated")和環境返回的"next"tensordict 狀態)的 TensorDict。傳遞給此模組的資料應結構化為[*B, T, *F],其中B是批處理大小,T是時間維度,F是特徵維度。tensordict 必須具有形狀[*B, T]。- 關鍵字引數:
params (TensorDictBase, 可選) – 一個巢狀 TensorDict,包含要傳遞給函式式值網路模組的引數。
target_params (TensorDictBase, 可選) – 一個巢狀 TensorDict,包含要傳遞給函式式值網路模組的目標引數。
device (torch.device, 可選) – 將例項化緩衝區的裝置。預設為
torch.get_default_device()。
- 返回:
一個更新後的 TensorDict,包含建構函式中定義的 advantage 和 value_error 鍵。
- value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, **kwargs)[source]¶
獲取值估計,通常用作值網路的目標值。
如果狀態值鍵存在於
tensordict.get(("next", self.tensor_keys.value))下,則將直接使用此值,而無需呼叫值網路。- 引數:
tensordict (TensorDictBase) – 包含要讀取的資料的 tensordict。
target_params (TensorDictBase, 可選) – 一個巢狀 TensorDict,包含要傳遞給函式式值網路模組的目標引數。
next_value (torch.Tensor, 可選) – 下一個狀態或狀態-動作對的值。與
target_params互斥。**kwargs – 要傳遞給值網路的關鍵字引數。
返回: 對應於狀態值的張量。