快捷方式

MCTSForest

class torchrl.data.MCTSForest(*, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, done_keys: List[NestedKey] | None = None, reward_keys: List[NestedKey] = None, observation_keys: List[NestedKey] = None, action_keys: List[NestedKey] = None, excluded_keys: List[NestedKey] = None, consolidated: bool | None = None)[原始檔]

MCTS 樹的集合。

警告

此類目前正處於積極開發中。API 可能會頻繁更改,請注意。

此類旨在將 rollout 儲存在 storage 中,並基於該資料集中給定的根節點生成樹。

關鍵詞引數:
  • data_map (TensorDictMap, 可選) – 用於儲存資料(觀測、獎勵、狀態等)的 storage。如果未提供,將使用 observation_keysaction_keys 的列表作為 in_keys,透過 from_tensordict_pair() 進行懶載入初始化。

  • node_map (TensorDictMap, 可選) – 將觀測空間對映到索引空間的 map。在內部,node map 用於收集從給定節點發出的所有可能的 branches。例如,如果一個觀測在 data map 中有兩個相關的 actions 和 outcomes,那麼 node_map 將返回一個數據結構,其中包含 data_map 中與這兩個 outcomes 對應的兩個索引。如果未提供,將使用 observation_keys 的列表作為 in_keys,並使用 QueryModule 作為 out_keys,透過 from_tensordict_pair() 進行懶載入初始化。

  • max_size (int, 可選) – maps 的大小。如果未提供,則預設為 data_map.max_size(如果可找到),然後是 node_map.max_size。如果這些都未提供,則預設為 1000

  • done_keys (NestedKey 列表, 可選) – 環境的 done keys。如果未提供,則預設為 ("done", "terminated", "truncated")。可以使用 get_keys_from_env() 自動確定 keys。

  • action_keys (NestedKey 列表, 可選) – 環境的 action keys。如果未提供,則預設為 ("action",)。可以使用 get_keys_from_env() 自動確定 keys。

  • reward_keys (NestedKey 列表, 可選) – 環境的 reward keys。如果未提供,則預設為 ("reward",)。可以使用 get_keys_from_env() 自動確定 keys。

  • observation_keys (NestedKey 列表, 可選) – 環境的 observation keys。如果未提供,則預設為 ("observation",)。可以使用 get_keys_from_env() 自動確定 keys。

  • excluded_keys (NestedKey 列表, 可選) – 要從資料 storage 中排除的 keys 列表。

  • consolidated (bool, 可選) – 如果為 True,則 data_map storage 將在磁碟上進行 consolidated。預設為 False

示例

>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> #  branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> #  a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
property action_keys: List[NestedKey]

Action Keys。

返回用於從環境輸入中檢索 actions 的 keys。預設的 action key 是“action”。

返回:

表示 action keys 的字串或元組列表。

property done_keys: List[NestedKey]

Done Keys。

返回用於指示 episode 已結束的 keys。預設的 done keys 是“done”、“terminated”和“truncated”。這些 keys 可用於環境的輸出中,以指示 episode 的結束。

返回:

表示 done keys 的字串列表。

extend(rollout, *, return_node: bool = False)[原始檔]

將一個 rollout 新增到 forest。

僅在 rollout 相互分歧的點和 rollout 的終點將節點新增到樹中。

如果不存在與 rollout 的前幾個步驟匹配的現有樹,則新增一個新的樹。僅為最後一步建立一個節點。

如果存在匹配的現有樹,則將 rollout 新增到該樹中。如果在某個步驟中 rollout 與樹中所有其他 rollout 分歧,則在 rollout 分歧的步驟之前建立一個新節點,併為 rollout 的最後一步建立一個葉節點。如果 rollout 的所有步驟都與之前新增的 rollout 匹配,則沒有任何改變。如果 rollout 匹配到樹的葉節點,但繼續超出該節點,則該節點會擴充套件到 rollout 的末尾,並且不會建立新的節點。

引數:
  • rollout (TensorDict) – 要新增到 forest 的 rollout。

  • return_node (bool, 可選) – 如果為 True,該方法將返回新增的節點。預設為 False

返回:

新增到 forest 的節點。僅當

return_node 為 True 時才返回。

返回型別:

Tree

示例

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
...     'action': torch.tensor([1, 2, 3, 4, 5]),
...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
...     'observation': torch.tensor([  0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
...     'action': torch.tensor([1, 2, 6, 7]),
...     'next': {'observation': torch.tensor([123, 392, 235,  38])},
...     'observation': torch.tensor([  0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
    count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
    index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
    node_data=TensorDict(
        fields={
            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([]),
        device=cpu,
        is_shared=False),
    node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    rollout=TensorDict(
        fields={
            action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([2]),
        device=cpu,
        is_shared=False),
    subtree=Tree(
        _parent=NonTensorStack(
            [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
            batch_size=torch.Size([2]),
            device=None),
        count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        hash=NonTensorStack(
            [4341220243998689835, 6745467818783115365],
            batch_size=torch.Size([2]),
            device=None),
        node_data=LazyStackedTensorDict(
            fields={
                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        node_id=NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None),
        rollout=LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: LazyStackedTensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, -1]),
                    device=cpu,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, -1]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        index=None,
        subtree=None,
        specs=None,
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    hash=None,
    _parent=None,
    specs=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
get_keys_from_env(env: EnvBase)[原始檔]

給定一個環境,將缺失的 done、action 和 reward keys 寫入 Forest。

現有 keys 不會被覆蓋。

property observation_keys: List[NestedKey]

Observation Keys。

返回用於從環境輸出中檢索 observations 的 keys。預設的 observation key 是“observation”。

返回:

表示 observation keys 的字串或元組列表。

property reward_keys: List[NestedKey]

Reward Keys。

返回用於從環境輸出中檢索 rewards 的 keys。預設的 reward key 是“reward”。

返回:

表示 reward keys 的字串或元組列表。


© 版權所有 2022, Meta。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源