FlattenObservation¶
- class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, allow_positive_dim: bool = False)[source]¶
展平張量的相鄰維度。
- 引數:
first_dim (int) – 要展平維度的起始維度。
last_dim (int) – 要展平維度的結束維度。
in_keys (NestedKey 序列, 可選) – 要展平的條目。如果未提供,則假定為
["pixels"]。out_keys (NestedKey 序列, 可選) – 展平後的觀察鍵。如果未提供,則假定為
in_keys。allow_positive_dim (bool, 可選) – 如果為
True,則接受正數維度。FlattenObservation會將這些維度對映到輸入張量的第 n 個特徵維度(即父環境批次大小後的第 n 個維度)。預設為 False,即不允許非負維度。
- forward(tensordict: TensorDictBase) TensorDictBase¶
讀取輸入的 tensordict,並對選定的鍵應用轉換。
對於僅與父環境相關的任何操作(例如 FrameSkip),應修改 _step 方法。
_call()僅在需要修改輸入的 tensordict 時才應被覆蓋。_call()將由TransformedEnv.step()和TransformedEnv.reset()呼叫。
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]¶
轉換觀察 Spec,使結果 Spec 與轉換對映匹配。
- 引數:
observation_spec (TensorSpec) – 轉換前的 Spec
- 返回:
轉換後的預期 Spec