TimeMaxPool¶
- class torchrl.envs.transforms.TimeMaxPool(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, T: int = 1, reset_key: NestedKey | None = None)[原始碼]¶
取最近 T 個觀察值的每個位置的最大值。
此變換對所有 in_keys 張量在最後 T 個時間步的每個位置取最大值。
- 引數:
in_keys (Sequence of NestedKey, 可選) – 應用最大池化的輸入鍵。如果為空,預設為“observation”。
out_keys (Sequence of NestedKey, 可選) – 輸出將被寫入的輸出鍵。如果為空,預設為 in_keys。
T (int, 可選) – 應用最大池化的時間步數。
reset_key (NestedKey | None, 可選) – 用作部分重置指示符的重置鍵。必須是唯一的。如果未提供,則預設為父環境的唯一重置鍵(如果只有一個),否則會引發異常。
示例
>>> from torchrl.envs import GymEnv >>> base_env = GymEnv("Pendulum-v1") >>> env = TransformedEnv(base_env, TimeMaxPool(in_keys=["observation"], T=10)) >>> torch.manual_seed(0) >>> env.set_seed(0) >>> rollout = env.rollout(10) >>> print(rollout["observation"]) # values should be increasing up until the 10th step tensor([[ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0216, 0.0000], [ 0.0000, 0.1149, 0.0000], [ 0.0000, 0.1990, 0.0000], [ 0.0000, 0.2749, 0.0000], [ 0.0000, 0.3281, 0.0000], [-0.9290, 0.3702, -0.8978]])
注意
TimeMaxPool目前僅支援根級別的done訊號。巢狀的done,例如在 MARL 設定中發現的,目前不受支援。如果需要此功能,請在 TorchRL 倉庫中提出 Issue。- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]¶
變換觀察規範,使結果規範與變換對映匹配。
- 引數:
observation_spec (TensorSpec) – 變換前的規範
- 返回:
變換後預期的規範