快捷方式

PrioritizedSliceSampler

class torchrl.data.replay_buffers.PrioritizedSliceSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: torch.dtype = torch.float32, reduction: str = 'max', *, num_slices: int = None, slice_len: int = None, end_key: NestedKey | None = ('next', 'done'), traj_key: NestedKey | None = 'episode', ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | Tuple[bool | int, bool | int] = False, max_priority_within_buffer: bool = False)[source]

根據開始和停止訊號,沿第一維度使用優先採樣對資料切片進行取樣。

此類根據“Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015 年論文”中提出的優先順序權重,有替換地對子軌跡進行取樣。

優先經驗回放” (https://arxiv.org/abs/1511.05952)

更多資訊請參見 SliceSamplerPrioritizedSampler

警告

PrioritizedSliceSampler 將檢視單個轉換的優先順序,並據此對起始點進行取樣。這意味著如果低優先順序轉換緊隨更高優先順序的轉換,它們也可能出現在樣本中;而高優先順序但更靠近軌跡末尾的轉換,如果不能用作起始點,則可能永遠不會被取樣。目前,使用者有責任使用 update_priority() 方法聚合軌跡項的優先順序。

引數:
  • alpha (float) – 指數 α 決定了優先化的程度,α = 0 對應於均勻取樣的情況。

  • beta (float) – 重要性取樣的負指數。

  • eps (float, 可選) – 新增到優先順序中的 delta 值,以確保緩衝區不包含零優先順序。預設為 1e-8。

  • reduction (str, 可選) – 多維 tensordicts(即儲存的軌跡)的歸約方法。可以是“max”、“min”、“median”或“mean”之一。

關鍵字引數:
  • num_slices (int) – 要取樣的切片數量。批處理大小(batch-size)必須大於或等於 num_slices 引數。與 slice_len 互斥。

  • slice_len (int) – 要取樣的切片的長度。批處理大小(batch-size)必須大於或等於 slice_len 引數,並且可以被其整除。與 num_slices 互斥。

  • end_key (NestedKey, 可選) – 指示軌跡(或 episode)結束的鍵。預設為 ("next", "done")

  • traj_key (NestedKey, 可選) – 指示軌跡的鍵。預設為 "episode"(TorchRL 資料集中常用)。

  • ends (torch.Tensor, 可選) – 包含執行結束訊號的 1d 布林張量。當獲取 end_keytraj_key 代價較高,或者當此訊號已準備好時使用。必須與 cache_values=True 一起使用,且不能與 end_keytraj_key 結合使用。

  • trajectories (torch.Tensor, 可選) – 包含執行 ID 的 1d 整型張量。當獲取 end_keytraj_key 代價較高,或者當此訊號已準備好時使用。必須與 cache_values=True 一起使用,且不能與 end_keytraj_key 結合使用。

  • cache_values (bool, 可選) –

    用於靜態資料集。將快取軌跡的開始和結束訊號。即使在呼叫 extend 期間軌跡索引發生變化,也可以安全使用,因為此操作將清除快取。

    警告

    cache_values=True 在取樣器與由另一個緩衝區擴充套件的儲存一起使用時將不起作用。例如

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! Does not erase the buffer from the sampler of buffer0
    >>> buffer1.extend(data)
    

    警告

    cache_values=True 如果緩衝區在程序之間共享,一個程序負責寫入,一個程序負責取樣,則將無法按預期工作,因為清除快取只能在本地完成。

  • truncated_key (NestedKey, 可選) – 如果不是 None,此引數指示截斷訊號應寫入輸出資料中的位置。這用於向值估計器指示提供的軌跡中斷的位置。預設為 ("next", "truncated")。此功能僅適用於 TensorDictReplayBuffer 例項(否則截斷鍵將在 sample() 方法返回的資訊字典中返回)。

  • strict_length (bool, 可選) – 如果為 False,長度小於 slice_len(或 batch_size // num_slices)的軌跡將被允許出現在批處理中。如果為 True,則將過濾掉短於要求的軌跡。請注意,這可能導致實際的 batch_size 小於請求的大小!軌跡可以使用 split_trajectories() 函式進行分割。預設為 True

  • compile (booldict of kwargs, 可選) – 如果為 Truesample() 方法的瓶頸部分將使用 compile() 進行編譯。關鍵字引數也可以透過此引數傳遞給 torch.compile。預設為 False

  • span (bool, int, Tuple[bool | int, bool | int], 可選) – 如果提供,取樣的軌跡將跨越左側和/或右側。這意味著提供的元素數量可能少於所需數量。布林值表示每個軌跡至少會取樣一個元素。整數 i 表示每個取樣的軌跡至少會收集 slice_len - i 個樣本。使用元組可以精細控制左側(儲存軌跡的開始)和右側(儲存軌跡的結束)的跨度。

  • max_priority_within_buffer (bool, 可選) – 如果為 True,則在緩衝區內跟蹤最大優先順序。如果為 False,則最大優先順序跟蹤自採樣器例項化以來的最大值。預設為 False

示例

>>> import torch
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler
>>> from tensordict import TensorDict
>>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6)
>>> data = TensorDict(
...     {
...         "observation": torch.randn(9,16),
...         "action": torch.randn(9, 1),
...         "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long),
...         "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long),
...         ("next", "observation"): torch.randn(9,16),
...         ("next", "reward"): torch.randn(9,1),
...         ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1),
...     },
...     batch_size=[9],
... )
>>> rb.extend(data)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 1, 1]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 1, 2]
>>> print("weight", info["_weight"].tolist())
weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
>>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
>>> rb.update_priority(torch.arange(0,9,1), priority=priority)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 2, 2]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 0, 1]
>>> print("weight", info["_weight"].tolist())
weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
update_priority(index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor], *, storage: TensorStorage | None = None) None

更新索引指向的資料的優先順序。

引數:
  • index (inttorch.Tensor) – 要更新優先順序的索引。

  • priority (Numbertorch.Tensor) – 索引元素的新的優先順序。

關鍵字引數:

storage (Storage, 可選) – 用於將 Nd 索引大小對映到 sum_tree 和 min_tree 的 1d 大小的儲存。僅在 index.ndim > 2 時需要。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源