快捷方式

TimeMaxPool

class torchrl.envs.transforms.TimeMaxPool(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, T: int = 1, reset_key: NestedKey | None = None)[原始碼]

取最近 T 個觀察值的每個位置的最大值。

此變換對所有 in_keys 張量在最後 T 個時間步的每個位置取最大值。

引數:
  • in_keys (Sequence of NestedKey, 可選) – 應用最大池化的輸入鍵。如果為空,預設為“observation”。

  • out_keys (Sequence of NestedKey, 可選) – 輸出將被寫入的輸出鍵。如果為空,預設為 in_keys

  • T (int, 可選) – 應用最大池化的時間步數。

  • reset_key (NestedKey | None, 可選) – 用作部分重置指示符的重置鍵。必須是唯一的。如果未提供,則預設為父環境的唯一重置鍵(如果只有一個),否則會引發異常。

示例

>>> from torchrl.envs import GymEnv
>>> base_env = GymEnv("Pendulum-v1")
>>> env = TransformedEnv(base_env, TimeMaxPool(in_keys=["observation"], T=10))
>>> torch.manual_seed(0)
>>> env.set_seed(0)
>>> rollout = env.rollout(10)
>>> print(rollout["observation"])  # values should be increasing up until the 10th step
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0216,  0.0000],
        [ 0.0000,  0.1149,  0.0000],
        [ 0.0000,  0.1990,  0.0000],
        [ 0.0000,  0.2749,  0.0000],
        [ 0.0000,  0.3281,  0.0000],
        [-0.9290,  0.3702, -0.8978]])

注意

TimeMaxPool 目前僅支援根級別的 done 訊號。巢狀的 done,例如在 MARL 設定中發現的,目前不受支援。如果需要此功能,請在 TorchRL 倉庫中提出 Issue。

forward(tensordict: TensorDictBase) TensorDictBase[原始碼]

讀取輸入的 tensordict,並對選定的鍵應用此變換。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]

變換觀察規範,使結果規範與變換對映匹配。

引數:

observation_spec (TensorSpec) – 變換前的規範

返回:

變換後預期的規範

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲取解答

檢視資源