
PrioritizedSampler

class torchrl.data.replay_buffers.PrioritizedSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', max_priority_within_buffer: bool = False)[source]

Prioritized sampler for replay buffers.

Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952)

Args:
  • max_capacity (int) – maximum capacity of the buffer.

  • alpha (float) – exponent α determines how much prioritization is used, with α = 0 corresponding to the uniform case.

  • beta (float) – importance-sampling negative exponent (see the sketch after this list).

  • eps (float, optional) – delta added to the priorities to ensure that the buffer does not contain null priorities. Defaults to 1e-8.

  • reduction (str, optional) – the reduction method for multidimensional tensordicts (i.e., stored trajectories). Can be one of "max", "min", "median" or "mean". Defaults to "max".

  • max_priority_within_buffer (bool, optional) – if True, the maximum priority is tracked within the buffer. When False, the maximum priority tracks the maximum value since the instantiation of the sampler. Defaults to False.
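Concretely, following Schaul et al. (2015), an item with priority p_i is sampled with probability P(i) ∝ p_i^α, and the importance-sampling weight reported in info["_weight"] has the form (N · P(i))^(−β). Below is a minimal sketch of these two formulas only; it assumes nothing about TorchRL's internals, which use a sum_tree/min_tree rather than explicit normalization:

>>> import torch
>>> priorities = torch.tensor([0.0, 1000.0]) + 1e-8  # eps keeps priorities non-null
>>> alpha, beta = 1.0, 1.0
>>> probs = priorities ** alpha / (priorities ** alpha).sum()  # P(i) = p_i^α / Σ_k p_k^α
>>> weights = (len(priorities) * probs) ** (-beta)             # w_i = (N · P(i))^(−β)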

Examples

>>> import torch
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> rb.add(data_0)
>>> rb.add(data_1)
>>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample)
TensorDict(
        fields={
            action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
            reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([10]),
        device=cpu,
        is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
       1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
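Because item 1's priority (1000) dominates item 0's (0, lifted only by eps), all ten draws land on index 1; the _weight entries are the corresponding importance-sampling weights.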

Note

Using a TensorDictReplayBuffer simplifies the process of updating the priorities:

>>> import torch
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = TDRB(
...     storage=LazyTensorStorage(10),
...     sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
...     priority_key="priority",  # This kwarg isn't present in regular RBs
... )
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> data = torch.stack([data_0, data_1])
>>> rb.extend(data)
>>> rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample['index'])  # The index is packed with the tensordict
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
update_priority(index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor], *, storage: TensorStorage | None = None) → None[source]

Update the priority of the data pointed to by the index.

Args:
  • index (int or torch.Tensor) – indexes of the priorities to be updated.

  • priority (Number or torch.Tensor) – new priorities of the indexed elements.

Keyword Args:

storage (Storage, optional) – a storage used to map the Nd index size to the 1d size of the sum_tree and min_tree. Only required whenever index.ndim > 2.
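In a typical prioritized-replay training loop, priorities are refreshed after every sample from the latest TD errors. A minimal sketch, where compute_td_error and num_updates are hypothetical placeholders and only rb.sample() and rb.update_priority() come from the API documented above:

>>> # Sketch only: `compute_td_error` and `num_updates` are hypothetical,
>>> # not part of TorchRL.
>>> for _ in range(num_updates):
...     sample, info = rb.sample(32, return_info=True)
...     td_error = compute_td_error(sample)  # one TD error per sampled element
...     # abs() + eps keeps the new priorities strictly positive
...     rb.update_priority(torch.as_tensor(info["index"]), priority=td_error.abs() + 1e-8)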
