PermuteTransform¶
- class torchrl.envs.transforms.PermuteTransform(dims, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)[source]¶
置換變換。
根據期望的維度對輸入張量進行置換。置換必須沿著特徵維度(而不是批次維度)提供。
- 引數:
dims (int 列表) – 維度的置換順序。必須是
[-(len(dims)), ..., -1]這些維度的重新排序。in_keys (NestedKeys 列表) – 輸入條目(讀取)。
out_keys (NestedKeys 列表) – 輸入條目(寫入)。如果未提供,預設為
in_keys。in_keys_inv (NestedKeys 列表) – 在呼叫
inv()期間讀取的輸入條目。out_keys_inv (NestedKeys 列表) – 在呼叫
inv()期間寫入的輸入條目。如果未提供,預設為in_keys_in。
示例
>>> from torchrl.envs.libs.gym import GymEnv >>> base_env = GymEnv("ALE/Pong-v5") >>> base_env.rollout(2) TensorDict( fields={ action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False), done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False), reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False), pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False) >>> env = TransformedEnv(base_env, PermuteTransform((-1, -3, -2), in_keys=["pixels"])) >>> env.rollout(2) # channels are at the end TensorDict( fields={ action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False), done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False), reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False), pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False)
- transform_input_spec(input_spec: TensorSpec) TensorSpec[source]¶
轉換輸入規範,使得結果規範與轉換對映匹配。
- 引數:
input_spec (TensorSpec) – 轉換前的規範
- 返回:
轉換後預期的規範
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]¶
轉換觀測規範,使得結果規範與轉換對映匹配。
- 引數:
observation_spec (TensorSpec) – 轉換前的規範
- 返回:
轉換後預期的規範