快捷方式

ValueEstimatorBase

class torchrl.objectives.value.ValueEstimatorBase(*args, **kwargs)[source]

值函式模組的抽象父類。

ValueFunctionBase.forward() 方法將計算值(由值網路給出)和值估計(由值估計器給出)以及優勢,並將這些值寫入輸出 tensordict。

如果只需要值估計,則應使用 ValueFunctionBase.value_estimate()

default_keys

_AcceptedKeys 的別名

abstract forward(tensordict: TensorDictBase, *, params: TensorDictBase | None = None, target_params: TensorDictBase | None = None) TensorDictBase[source]

計算給定 tensordict 中的資料的優勢估計。

如果提供了函式式模組,則可以將包含引數(以及相關的目標引數)的巢狀 TensorDict 傳遞給該模組。

引數:

tensordict (TensorDictBase) – 包含計算值估計和 TDEstimate 所需資料(觀察鍵、"action"("next", "reward")("next", "done")("next", "terminated") 和環境返回的 "next" tensordict 狀態)的 TensorDict。傳遞給此模組的資料應結構化為 [*B, T, *F],其中 B 是批處理大小,T 是時間維度,F 是特徵維度。tensordict 必須具有形狀 [*B, T]

關鍵字引數:
  • params (TensorDictBase, 可選) – 一個巢狀 TensorDict,包含要傳遞給函式式值網路模組的引數。

  • target_params (TensorDictBase, 可選) – 一個巢狀 TensorDict,包含要傳遞給函式式值網路模組的目標引數。

  • device (torch.device, 可選) – 將例項化緩衝區的裝置。預設為 torch.get_default_device()

返回:

一個更新後的 TensorDict,包含建構函式中定義的 advantage 和 value_error 鍵。

set_keys(**kwargs) None[source]

設定 tensordict 鍵名。

value_estimate(tensordict, target_params: TensorDictBase | None = None, next_value: torch.Tensor | None = None, **kwargs)[source]

獲取值估計,通常用作值網路的目標值。

如果狀態值鍵存在於 tensordict.get(("next", self.tensor_keys.value)) 下,則將直接使用此值,而無需呼叫值網路。

引數:
  • tensordict (TensorDictBase) – 包含要讀取的資料的 tensordict。

  • target_params (TensorDictBase, 可選) – 一個巢狀 TensorDict,包含要傳遞給函式式值網路模組的目標引數。

  • next_value (torch.Tensor, 可選) – 下一個狀態或狀態-動作對的值。與 target_params 互斥。

  • **kwargs – 要傳遞給值網路的關鍵字引數。

返回: 對應於狀態值的張量。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源