Tree¶
- class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: "'Tree'" = None, _parent: 'weakref.ref | List[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]¶
- property branching_action: torch.Tensor | TensorDictBase | None¶
返回由此特定節點分支出的 Action。
- 返回:
如果節點沒有父節點,則返回 Tensor、TensorDict 或 None。
另請參閱
當 Rollout 資料包含單個 Step 時,這將等於
prev_action。另請參閱
所有與樹中給定節點(或 Observation)關聯的 Action.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T¶
將 TensorDict 儲存到磁碟。
此函式是
memmap()的代理。
- edges() List[Tuple[int, int]][source]¶
獲取樹中的 Edge 列表。
每個 Edge 表示為兩個節點 ID 的元組:父節點 ID 和子節點 ID。樹使用廣度優先搜尋 (BFS) 遍歷,以確保訪問所有 Edge。
- 返回:
一個元組列表,其中每個元組包含一個父節點 ID 和一個子節點 ID。
- classmethod fields()¶
返回描述此 Dataclass 欄位的元組。
接受 Dataclass 或其例項。元組元素型別為 Field。
- classmethod from_tensordict(tensordict, non_tensordict=None, safe=True)¶
用於例項化新 Tensor 類別物件的 Tensor 類別包裝器。
- 引數:
tensordict (TensorDict) – Tensor 型別的字典
non_tensordict (dict) – 包含非 Tensor 和巢狀 Tensor 類別物件的字典
- property full_action_spec¶
樹的 Action Spec。
這是 Tree.specs[‘input_spec’, ‘full_action_spec’] 的別名。
- property full_done_spec¶
樹的 Done Spec。
這是 Tree.specs[‘output_spec’, ‘full_done_spec’] 的別名。
- property full_observation_spec¶
樹的 Observation Spec。
這是 Tree.specs[‘output_spec’, ‘full_observation_spec’] 的別名。
- property full_reward_spec¶
樹的 Reward Spec。
這是 Tree.specs[‘output_spec’, ‘full_reward_spec’] 的別名。
- property full_state_spec¶
樹的 State Spec。
這是 Tree.specs[‘input_spec’, ‘full_state_spec’] 的別名。
- get(key: NestedKey, *args, **kwargs)¶
獲取使用輸入 Key 儲存的值。
- 引數:
key (str, tuple of str) – 要查詢的 Key。如果是 str 元組,則等效於 getattr 的鏈式呼叫。
default – 如果在 Tensorclass 中找不到 Key 的預設值。
- 返回:
使用輸入 Key 儲存的值
- property is_terminal: bool | torch.Tensor¶
如果樹沒有子節點,則返回 True。
- classmethod load(prefix: str | Path, *args, **kwargs) T¶
從磁碟載入 TensorDict。
此類方法是
load_memmap()的代理。
- load_(prefix: str | Path, *args, **kwargs)¶
將 TensorDict 從磁碟載入到當前 TensorDict 中。
此類方法是
load_memmap_()的代理。
- classmethod load_memmap(prefix: str | Path, device: torch.device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None) T¶
從磁碟載入 Memory-mapped TensorDict。
- 引數:
prefix (str or Path to folder) – 應從中獲取已儲存 TensorDict 的資料夾路徑。
device (torch.device or equivalent, optional) – 如果提供,資料將非同步地轉換為該 Device。支援 “meta” Device,在這種情況下資料不會被載入,而是建立一組空的“meta” Tensor。這對於瞭解總模型大小和結構而無需實際開啟任何檔案很有用。
non_blocking (bool, optional) – 如果為
True,在 Device 上載入 Tensor 後不會呼叫 synchronize。預設為False。out (TensorDictBase, optional) – 可選的 TensorDict,資料應寫入其中。
示例
>>> from tensordict import TensorDict >>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0) >>> td.memmap("./saved_td") >>> td_load = TensorDict.load_memmap("./saved_td") >>> assert (td == td_load).all()
此方法還支援載入巢狀的 TensorDict。
示例
>>> nested = TensorDict.load_memmap("./saved_td/nested") >>> assert nested["e"] == 0
TensorDict 也可以載入到“meta” Device 上,或者作為 Fake Tensor 載入。
示例
>>> import tempfile >>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}}) >>> with tempfile.TemporaryDirectory() as path: ... td.save(path) ... td_load = TensorDict.load_memmap(path, device="meta") ... print("meta:", td_load) ... from torch._subclasses import FakeTensorMode ... with FakeTensorMode(): ... td_load = TensorDict.load_memmap(path) ... print("fake:", td_load) meta: TensorDict( fields={ a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False)}, batch_size=torch.Size([]), device=meta, is_shared=False) fake: TensorDict( fields={ a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)¶
嘗試將 State_dict 就地載入到目標 Tensorclass 上。
- classmethod make_node(data: TensorDictBase, *, device: torch.device | None = None, batch_size: torch.Size | None = None, specs: Composite | None = None) Tree[source]¶
根據給定資料建立一個新節點。
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T¶
將所有 Tensor 寫入新 TensorDict 中相應的 Memory-mapped Tensor。
- 引數:
prefix (str) – Memory-mapped Tensor 儲存所在的目錄字首。目錄樹結構將模仿 TensorDict 的結構。
copy_existing (bool) – 如果為 False(預設),如果 TensorDict 中的某個條目已經是儲存在磁碟上的 Tensor 且具有關聯檔案,但未按照 Prefix 儲存在正確位置,則會引發異常。如果為
True,任何現有 Tensor 都將被複制到新位置。
- 關鍵字引數:
num_threads (int, optional) – 用於寫入 Memmap Tensor 的執行緒數量。預設為 0。
return_early (bool, optional) – 如果為
True且num_threads>0,該方法將返回 TensorDict 的 Future。share_non_tensor (bool, optional) – 如果為
True,非 Tensor 資料將在程序間共享,並且在單個節點內任意 worker 上的寫入操作(如原地更新或設定)將更新所有其他 worker 上的值。如果非 Tensor 葉子節點的數量很高(例如,共享大型非 Tensor 資料堆疊),這可能導致 OOM 或類似錯誤。預設為False。existsok (bool, optional) – 如果為
False,如果相同路徑下已存在 Tensor,則會引發異常。預設為True。
TensorDict 隨後被鎖定,這意味著任何非原地寫入操作(例如,重新命名、設定或移除條目)將丟擲異常。一旦 TensorDict 解鎖,記憶體對映屬性將變為
False,因為跨程序身份不再保證。- 返回:
如果
return_early=False,則返回一個 Tensor 儲存在磁碟上的新 tensordict;否則返回一個TensorDictFuture例項。
注意
以這種方式序列化對於深度巢狀的 tensordict 來說可能很慢,因此不建議在訓練迴圈內呼叫此方法。
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T¶
將所有 Tensor 原地寫入到相應的記憶體對映 Tensor。
- 引數:
prefix (str) – Memory-mapped Tensor 儲存所在的目錄字首。目錄樹結構將模仿 TensorDict 的結構。
copy_existing (bool) – 如果為 False(預設),如果 TensorDict 中的某個條目已經是儲存在磁碟上的 Tensor 且具有關聯檔案,但未按照 Prefix 儲存在正確位置,則會引發異常。如果為
True,任何現有 Tensor 都將被複制到新位置。
- 關鍵字引數:
num_threads (int, optional) – 用於寫入 Memmap Tensor 的執行緒數量。預設為 0。
return_early (bool, optional) – 如果為
True且num_threads>0,該方法將返回一個 tensordict 的 future。可以透過使用 future.result() 來查詢返回的 tensordict。share_non_tensor (bool, optional) – 如果為
True,非 Tensor 資料將在程序間共享,並且在單個節點內任意 worker 上的寫入操作(如原地更新或設定)將更新所有其他 worker 上的值。如果非 Tensor 葉子節點的數量很高(例如,共享大型非 Tensor 資料堆疊),這可能導致 OOM 或類似錯誤。預設為False。existsok (bool, optional) – 如果為
False,如果相同路徑下已存在 Tensor,則會引發異常。預設為True。
TensorDict 隨後被鎖定,這意味著任何非原地寫入操作(例如,重新命名、設定或移除條目)將丟擲異常。一旦 TensorDict 解鎖,記憶體對映屬性將變為
False,因為跨程序身份不再保證。- 返回:
如果
return_early=False,返回自身;否則返回一個TensorDictFuture例項。
注意
以這種方式序列化對於深度巢狀的 tensordict 來說可能很慢,因此不建議在訓練迴圈內呼叫此方法。
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T¶
建立一個具有與原始 tensordict 相同形狀的無內容的記憶體對映 tensordict。
- 引數:
prefix (str) – Memory-mapped Tensor 儲存所在的目錄字首。目錄樹結構將模仿 TensorDict 的結構。
copy_existing (bool) – 如果為 False(預設),如果 TensorDict 中的某個條目已經是儲存在磁碟上的 Tensor 且具有關聯檔案,但未按照 Prefix 儲存在正確位置,則會引發異常。如果為
True,任何現有 Tensor 都將被複制到新位置。
- 關鍵字引數:
num_threads (int, optional) – 用於寫入 Memmap Tensor 的執行緒數量。預設為 0。
return_early (bool, optional) – 如果為
True且num_threads>0,該方法將返回 TensorDict 的 Future。share_non_tensor (bool, optional) – 如果為
True,非 Tensor 資料將在程序間共享,並且在單個節點內任意 worker 上的寫入操作(如原地更新或設定)將更新所有其他 worker 上的值。如果非 Tensor 葉子節點的數量很高(例如,共享大型非 Tensor 資料堆疊),這可能導致 OOM 或類似錯誤。預設為False。existsok (bool, optional) – 如果為
False,如果相同路徑下已存在 Tensor,則會引發異常。預設為True。
TensorDict 隨後被鎖定,這意味著任何非原地寫入操作(例如,重新命名、設定或移除條目)將丟擲異常。一旦 TensorDict 解鎖,記憶體對映屬性將變為
False,因為跨程序身份不再保證。- 返回:
如果
return_early=False,則返回一個數據儲存為記憶體對映 tensor 的新TensorDict例項;否則返回一個TensorDictFuture例項。
注意
這是將一組大型緩衝區寫入磁碟的推薦方法,因為
memmap_()將複製資訊,這對於大型內容可能很慢。示例
>>> td = TensorDict({ ... "a": torch.zeros((3, 64, 64), dtype=torch.uint8), ... "b": torch.zeros(1, dtype=torch.int64), ... }, batch_size=[]).expand(1_000_000) # expand does not allocate new memory >>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()¶
如果記憶體對映 tensordict 具有
saved_path,則重新整理其內容。如果沒有與其關聯的路徑,此方法將引發異常。
- property node_observation: torch.Tensor | TensorDictBase¶
返回與此特定節點關聯的觀測值。
這是定義節點在分支發生之前的觀測值(或觀測值集合)。如果節點包含
rollout()屬性,則節點觀測值通常與上次執行的操作產生的觀測值相同,即node.rollout[..., -1]["next", "observation"]。如果與樹的規格關聯的觀測值鍵不止一個,則返回一個
TensorDict例項。為了更一致的表示,請參閱
node_observations。
- property node_observations: torch.Tensor | TensorDictBase¶
返回以 TensorDict 格式表示的、與此特定節點關聯的觀測值。
這是定義節點在分支發生之前的觀測值(或觀測值集合)。如果節點包含
rollout()屬性,則節點觀測值通常與上次執行的操作產生的觀測值相同,即node.rollout[..., -1]["next", "observation"]。如果與樹的規格關聯的觀測值鍵不止一個,則返回一個
TensorDict例項。為了更一致的表示,請參閱
node_observations。
- property num_children: int¶
此節點的子節點數量。
等於
self.subtree堆疊中的元素數量。
- num_vertices(*, count_repeat: bool = False) int[source]¶
返回 Tree 中唯一頂點的數量。
- 關鍵字引數:
count_repeat (bool, optional) –
確定是否計算重複頂點。
如果為
False,則僅計算每個唯一頂點一次。如果為
True,如果頂點出現在不同的路徑中,則多次計算。
預設為
False。- 返回:
Tree 中唯一頂點的數量。
- 返回型別:
int
- property parent: Tree | None¶
節點的父節點。
如果節點有父節點並且此物件仍然存在於 python 工作空間中,此屬性將返回它。
對於重新分支的樹,此屬性可能返回一個樹堆疊,其中堆疊中的每個索引對應於不同的父節點。
注意
parent 屬性的內容將匹配,但身份不匹配:tensorclass 物件是使用相同的 tensor(即,指向相同記憶體位置的 tensor)重建的。
- 返回:
如果父節點資料超出作用域或節點是根節點,則返回包含父節點資料的
Tree,否則返回None。
- plot(backend: str = 'plotly', figure: str = 'tree', info: List[str] = None, make_labels: Callable[[Any, ...], Any] | None = None)[source]¶
使用指定的後端和圖型別繪製樹的視覺化圖。
- 引數:
backend – 要使用的繪圖後端。目前僅支援 'plotly'。
figure – 要繪製的圖型別。可以是 'tree' 或 'box'。
info – 要包含在圖中的附加資訊列表(目前未使用)。
make_labels – 一個可選函式,用於為繪圖生成自定義標籤。
- 引發:
NotImplementedError – 如果指定了不受支援的後端或圖型別。
- property prev_action: torch.Tensor | TensorDictBase | None¶
就在生成此節點的觀測值之前執行的操作。
- 返回:
如果節點沒有父節點,則返回 Tensor、TensorDict 或 None。
另請參閱
只要 rollout 資料包含單個步驟,這將等於
branching_action。另請參閱
所有與樹中給定節點(或 Observation)關聯的 Action.
- rollout_from_path(path: Tuple[int]) TensorDictBase | None[source]¶
檢索沿著樹中給定路徑的 rollout 資料。
對於路徑中的每個節點,rollout 資料沿著最後一個維度(dim=-1)連線。如果沿著路徑未找到 rollout 資料,則返回
None。- 引數:
path – 一個表示樹中路徑的整數元組。
- 返回:
沿著路徑連線後的 rollout 資料,如果未找到資料則返回 None。
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T¶
將 TensorDict 儲存到磁碟。
此函式是
memmap()的代理。
- property selected_actions: torch.Tensor | TensorDictBase | None¶
返回一個包含從此節點分支出去的所有選定操作的 tensor。
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)¶
設定一個新的鍵值對。
- 引數:
key (str, tuple of str) – 要設定的鍵的名稱。如果是字串元組,則等同於鏈式呼叫 getattr 後跟一個最終的 setattr。
value (Any) – 要儲存在 tensorclass 中的值。
inplace (bool, optional) – 如果為
True,set 將嘗試原地更新值。如果為False或鍵不存在,值將被簡單寫入到其目的地。
- 返回:
自身
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]¶
返回一個 state_dict 字典,可用於儲存和載入 tensorclass 中的資料。
- to_tensordict(*, retain_none: bool | None = None) TensorDict¶
將 tensorclass 轉換為常規的 TensorDict。
複製所有條目。記憶體對映和共享記憶體 tensor 被轉換為常規 tensor。
- 引數:
retain_none (bool) –
如果為
True,None值將被寫入 tensordict 中。否則將被丟棄。預設:True。注意
從 v0.8 開始,預設值將更改為
False。- 返回:
一個包含與 tensorclass 相同值的新 TensorDict 物件。
- unbind(dim: int)¶
返回一個由索引 tensorclass 例項組成的元組,這些例項沿指定維度解除繫結。
結果 tensorclass 例項將共享初始 tensorclass 例項的儲存。
- valid_paths()[source]¶
生成樹中的所有有效路徑。
有效路徑是從根節點開始並在葉子節點結束的子節點索引序列。每個路徑表示為一個整數元組,其中每個整數對應於一個子節點的索引。
- 生成:
tuple – 樹中的一個有效路徑。
- vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') Dict[int | Tuple[int], Tree][source]¶
返回一個包含 Tree 頂點的對映。
- 關鍵字引數:
key_type (Literal["id", "hash", "path"], optional) –
指定用於頂點的鍵的型別。
"id": 使用頂點 ID 作為鍵。
"hash": 使用頂點的雜湊值作為鍵。
- "path": 使用到頂點的路徑作為鍵。這可能導致字典的長度比使用 "id" 或 "hash" 時更長,因為同一個節點可能是多條軌跡的一部分。
預設為
"hash"。
(注:原文此處有額外一句 "Defaults to an empty string, which may imply a default behavior.",與前一句矛盾且在此處無上下文,已省略不譯。)
- 返回:
一個將鍵對映到 Tree 頂點的字典。
- 返回型別:
Dict[int | Tuple[int], Tree]
- property visits: int | torch.Tensor¶
返回與此特定節點關聯的訪問次數。
這是
count屬性的別名。