快捷方式

FlattenObservation

class torchrl.envs.transforms.FlattenObservation(first_dim: int, last_dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, allow_positive_dim: bool = False)[source]

展平張量的相鄰維度。

引數:
  • first_dim (int) – 要展平維度的起始維度。

  • last_dim (int) – 要展平維度的結束維度。

  • in_keys (NestedKey 序列, 可選) – 要展平的條目。如果未提供,則假定為 ["pixels"]

  • out_keys (NestedKey 序列, 可選) – 展平後的觀察鍵。如果未提供,則假定為 in_keys

  • allow_positive_dim (bool, 可選) – 如果為 True,則接受正數維度。FlattenObservation 會將這些維度對映到輸入張量的第 n 個特徵維度(即父環境批次大小後的第 n 個維度)。預設為 False,即不允許非負維度。

forward(tensordict: TensorDictBase) TensorDictBase

讀取輸入的 tensordict,並對選定的鍵應用轉換。

對於僅與父環境相關的任何操作(例如 FrameSkip),應修改 _step 方法。_call() 僅在需要修改輸入的 tensordict 時才應被覆蓋。

_call() 將由 TransformedEnv.step()TransformedEnv.reset() 呼叫。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

轉換觀察 Spec,使結果 Spec 與轉換對映匹配。

引數:

observation_spec (TensorSpec) – 轉換前的 Spec

返回:

轉換後的預期 Spec

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源