快捷方式

DTActor

class torchrl.modules.DTActor(state_dim: int, action_dim: int, transformer_config: Dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None)[source]

Decision Transformer Actor 類。

Decision Transformer 的 Actor 類,用於輸出確定性動作,如 “Decision Transformer” <https://arxiv.org/abs/2202.05607.pdf> 中所述。返回確定性動作。

引數:
  • state_dim (int) – 狀態維度。

  • action_dim (int) – 動作維度。

  • transformer_config (Dict or DecisionTransformer.DTConfig, optional) – GPT2 transformer 的配置。預設為 default_config()

  • device (torch.device, optional) – 要使用的裝置。預設為 None。

示例

>>> model = DTActor(state_dim=4, action_dim=2,
...     transformer_config=DTActor.default_config())
>>> observation = torch.randn(32, 10, 4)
>>> action = torch.randn(32, 10, 2)
>>> return_to_go = torch.randn(32, 10, 1)
>>> output = model(observation, action, return_to_go)
>>> output.shape
torch.Size([32, 10, 2])
classmethod default_config()[source]

DTActor 的預設配置。

forward(observation: Tensor, action: Tensor, return_to_go: Tensor) Tensor[source]

定義每次呼叫時執行的計算。

應由所有子類重寫。

注意

雖然前向傳播的實現需要在此函式中定義,但之後應該呼叫 Module 例項,而不是直接呼叫此函式,因為前者負責執行已註冊的鉤子,而後者會默默忽略它們。

文件

檢視 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源