快捷方式

CatTensors

class torchrl.envs.transforms.CatTensors(in_keys: 序列[NestedKey] | None = None, out_key: NestedKey = 'observation_vector', dim: int = - 1, *, del_keys: bool = True, unsqueeze_if_oor: bool = False, sort: bool = True)[源]

將多個鍵連線成一個張量。

當多個鍵描述單個狀態時(例如,“observation_position” 和 “observation_velocity”),這尤其有用。

引數:
  • in_keys (巢狀鍵序列) – 要連線的鍵。如果為 None(或未提供),則在首次使用此變換時將從父環境中檢索這些鍵。此行為僅在設定了父環境時有效。

  • out_key (NestedKey) – 結果張量的鍵。

  • dim (int, 可選) – 進行連線的維度。預設為 -1

關鍵字引數:
  • del_keys (bool, 可選) – 如果為 True,輸入值將在連線後被刪除。預設為 True

  • unsqueeze_if_oor (bool, 可選) – 如果為 True,CatTensor 將檢查要連線的張量是否存在指定維度。如果不存在,則將沿該維度對張量進行 unsqueeze(擴充套件維度)。預設為 False

  • sort (bool, 可選) – 如果為 True,鍵將在變換中進行排序。否則,將使用使用者提供的順序。預設為 True

示例

>>> transform = CatTensors(in_keys=["key1", "key2"])
>>> td = TensorDict({"key1": torch.zeros(1, 1),
...     "key2": torch.ones(1, 1)}, [1])
>>> _ = transform(td)
>>> print(td.get("observation_vector"))
tensor([[0., 1.]])
>>> transform = CatTensors(in_keys=["key1", "key2"], dim=-2, unsqueeze_if_oor=True)
>>> td = TensorDict({"key1": torch.zeros(1),
...     "key2": torch.ones(1)}, [])
>>> _ = transform(td)
>>> print(td.get("observation_vector").shape)
torch.Size([2, 1])
forward(tensordict: TensorDictBase) TensorDictBase

讀取輸入的 tensordict,並對選定的鍵應用變換。

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[源]

變換觀察規範,使結果規範與變換對映匹配。

引數:

observation_spec (TensorSpec) – 變換前的規範

返回值:

變換後預期的規範

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源