DiscreteActionProjection¶
- 類 torchrl.envs.transforms.DiscreteActionProjection(num_actions_effective: int, max_actions: int, action_key: NestedKey = 'action', include_forward: bool = True)[source]¶
將離散動作從高維空間投影到低維空間。
給定一個編碼為獨熱向量(從 1 到 N)的離散動作和一個最大動作索引 num_actions(其中 num_actions < N),該變換會使輸出動作 action_out 最大為 num_actions。
如果輸入動作大於 num_actions,則將其替換為 0 到 num_actions-1 之間的隨機值。否則,保持原動作不變。這旨在用於應用於具有不同動作空間的多個離散控制環境的策略。
呼叫 DiscreteActionProjection.forward(例如從回放緩衝區或在 nn.Modules 序列中)將在
"in_keys"上呼叫 num_actions_effective -> max_actions 變換,而對 _call 的呼叫將被忽略。實際上,變換後的環境被指示僅更新內部 base_env 的輸入鍵,但原始輸入鍵將保持不變。- 引數:
num_actions_effective (int) – 考慮的最大動作數。
max_actions (int) – 此模組可讀取的最大動作數。
action_key (NestedKey, 可選) – 動作的鍵名。預設為“action”。
include_forward (bool, 可選) – 如果為
True,當模組由回放緩衝區或 nn.Module 鏈呼叫時,對 forward 的呼叫也會將動作從一個域對映到另一個域。預設為 True。
示例
>>> torch.manual_seed(0) >>> N = 3 >>> M = 2 >>> action = torch.zeros(N, dtype=torch.long) >>> action[-1] = 1 >>> td = TensorDict({"action": action}, []) >>> transform = DiscreteActionProjection(num_actions_effective=M, max_actions=N) >>> _ = transform.inv(td) >>> print(td.get("action")) tensor([1])
- transform_input_spec(input_spec: Composite)[source]¶
變換輸入規範,使結果規範與變換對映匹配。
- 引數:
input_spec (TensorSpec) – 變換前的規範
- 返回:
變換後預期的規範