CatTensors¶
- class torchrl.envs.transforms.CatTensors(in_keys: 序列[NestedKey] | None = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[源]¶
將多個鍵連線成一個張量。
當多個鍵描述單個狀態時(例如,“observation_position” 和 “observation_velocity”),這尤其有用。
- 引數:
in_keys (巢狀鍵序列) – 要連線的鍵。如果為 None(或未提供),則在首次使用此變換時將從父環境中檢索這些鍵。此行為僅在設定了父環境時有效。
out_key (NestedKey) – 結果張量的鍵。
dim (int, 可選) – 進行連線的維度。預設為
-1。
- 關鍵字引數:
del_keys (bool, 可選) – 如果為
True,輸入值將在連線後被刪除。預設為True。unsqueeze_if_oor (bool, 可選) – 如果為
True,CatTensor 將檢查要連線的張量是否存在指定維度。如果不存在,則將沿該維度對張量進行unsqueeze(擴充套件維度)。預設為False。sort (bool, 可選) – 如果為
True,鍵將在變換中進行排序。否則,將使用使用者提供的順序。預設為True。
示例
>>> transform = CatTensors(in_keys=["key1", "key2"]) >>> td = TensorDict({"key1": torch.zeros(1, 1), ... "key2": torch.ones(1, 1)}, [1]) >>> _ = transform(td) >>> print(td.get("observation_vector")) tensor([[0., 1.]]) >>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True) >>> td = TensorDict({"key1": torch.zeros(1), ... "key2": torch.ones(1)}, []) >>> _ = transform(td) >>> print(td.get("observation_vector").shape) torch.Size([2, 1])
- forward(tensordict: TensorDictBase) TensorDictBase¶
讀取輸入的 tensordict,並對選定的鍵應用變換。
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[源]¶
變換觀察規範,使結果規範與變換對映匹配。
- 引數:
observation_spec (TensorSpec) – 變換前的規範
- 返回值:
變換後預期的規範