快捷方式

PermuteTransform

class torchrl.envs.transforms.PermuteTransform(dims, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)[source]

置換變換。

根據期望的維度對輸入張量進行置換。置換必須沿著特徵維度(而不是批次維度)提供。

引數:
  • dims (int 列表) – 維度的置換順序。必須是 [-(len(dims)), ..., -1] 這些維度的重新排序。

  • in_keys (NestedKeys 列表) – 輸入條目(讀取)。

  • out_keys (NestedKeys 列表) – 輸入條目(寫入)。如果未提供,預設為 in_keys

  • in_keys_inv (NestedKeys 列表) – 在呼叫 inv() 期間讀取的輸入條目。

  • out_keys_inv (NestedKeys 列表) – 在呼叫 inv() 期間寫入的輸入條目。如果未提供,預設為 in_keys_in

示例

>>> from torchrl.envs.libs.gym import GymEnv
>>> base_env = GymEnv("ALE/Pong-v5")
>>> base_env.rollout(2)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> env = TransformedEnv(base_env, PermuteTransform((-1, -3, -2), in_keys=["pixels"]))
>>> env.rollout(2)  # channels are at the end
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
transform_input_spec(input_spec: TensorSpec) TensorSpec[source]

轉換輸入規範,使得結果規範與轉換對映匹配。

引數:

input_spec (TensorSpec) – 轉換前的規範

返回:

轉換後預期的規範

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

轉換觀測規範,使得結果規範與轉換對映匹配。

引數:

observation_spec (TensorSpec) – 轉換前的規範

返回:

轉換後預期的規範

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取適合初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源