split_trajectories¶
- torchrl.collectors.utils.split_trajectories(rollout_tensordict: TensorDictBase, *, prefix=None, trajectory_key: NestedKey | None = None, done_key: NestedKey | None = None, as_nested: bool = False) TensorDictBase[source]¶
一個用於軌跡分離的工具函式。
接收一個 tensordict,其中包含一個
traj_ids鍵,該鍵指示每個軌跡的 ID。在此基礎上,構建一個 B x T x … 零填充的 tensordict,其中 B 為批次大小,T 為最大持續時間。
- 引數:
rollout_tensordict (TensorDictBase) – 沿最後一個維度包含相鄰軌跡的 rollout。
- 關鍵字引數:
prefix (NestedKey, optional) – 用於讀取和寫入元資料的字首,例如
"traj_ids"(每個軌跡的可選整數 ID)以及指示哪些資料有效、哪些無效的"mask"條目。如果輸入包含"collector"條目,則預設為"collector",否則為()(無字首)。prefix作為遺留功能保留,最終將被棄用。儘可能優先使用trajectory_key或done_key。trajectory_key (NestedKey, optional) – 指向軌跡 ID 的鍵。覆蓋
done_key和prefix。如果未提供,則預設為(prefix, "traj_ids")。done_key (NestedKey, optional) – 指向
"done"訊號的鍵,如果無法直接恢復軌跡。預設為"done"。as_nested (bool or torch.layout, optional) –
是否將結果作為巢狀張量返回。預設為
False。如果提供了torch.layout,將使用它來構建巢狀張量,否則將使用預設佈局。注意
使用
split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask=mask_key)應該會得到與as_nested=False完全相同的結果。由於這是一個實驗性功能,並且依賴於 nested_tensors,其 API 未來可能會更改,因此我們將其設為可選功能。當as_nested=True時,執行時應該更快。注意
提供佈局允許使用者控制巢狀張量是使用
torch.strided佈局還是torch.jagged佈局。儘管在撰寫本文時前者具有稍微更多的功能,但後者因其與compile()更好的相容性,未來將成為 PyTorch 團隊的主要關注點。
- 返回值:
一個新的 tensordict,其前導維度對應於軌跡。同時添加了一個
"mask"布林值條目,它共享trajectory_key字首和 tensordict 形狀,並指示 tensordict 中的有效元素。如果找不到trajectory_key,則還會新增一個"traj_ids"條目。
示例
>>> from tensordict import TensorDict >>> import torch >>> from torchrl.collectors.utils import split_trajectories >>> obs = torch.cat([torch.arange(10), torch.arange(5)]) >>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)]) >>> done = torch.zeros(15, dtype=torch.bool) >>> done[9] = True >>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32), ... torch.ones(5, dtype=torch.int32)]) >>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15]) >>> data_split = split_trajectories(data, done_key="done") >>> print(data_split) TensorDict( fields={ mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False), traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False), trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False) >>> # check that split_trajectories got the trajectories right with the done signal >>> assert (data_split["traj_ids"] == data_split["trajectory"]).all() >>> print(data_split["mask"]) tensor([[ True, True, True, True, True, True, True, True, True, True], [ True, True, True, True, True, False, False, False, False, False]]) >>> data_split = split_trajectories(data, trajectory_key="trajectory") >>> print(data_split) TensorDict( fields={ mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False), obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False), trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)}, batch_size=torch.Size([2, 10]), device=None, is_shared=False)