TensorDictMaxValueWriter

class torchrl.data.replay_buffers.TensorDictMaxValueWriter(rank_key=None, reduction: str = 'sum', **kwargs)[source]

A Writer class for composable replay buffers that keeps the top elements based on some ranking key.

Parameters:
  • rank_key (str or tuple of str) – the key to rank the elements over. Defaults to ("next", "reward").

  • reduction (str) – the reduction method to use if the rank key has more than one element. Can be "max", "min", "mean", "median" or "sum" (a sketch illustrating this argument follows the examples below).

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
...     "key": torch.tensor(range(10, 20)),
...     "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
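
The reduction argument kicks in when the rank key carries more than one value per element. Below is a minimal sketch of this (an illustration, not part of the original examples, reusing the imports above), where each element's "key" holds two values that are averaged into a single score:

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(2),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key", reduction="mean"),
... )
>>> td = TensorDict({
...     "key": torch.tensor([[1.0, 9.0], [2.0, 2.0], [5.0, 5.0]]),  # means: 5.0, 2.0, 5.0
...     "obs": torch.tensor([0, 1, 2]),
... }, batch_size=3)
>>> rb.extend(td)  # the two elements with mean 5.0 are kept; obs 1 is dropped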

Note

This class is not compatible with storages of more than one dimension. This does not mean that storing trajectories is prohibited, but rather that the stored trajectories must be stored on a per-trajectory basis. Here are some examples of valid and invalid usages of the class. First, a flat buffer where we store individual transitions:

>>> from torchrl.data import LazyTensorStorage
>>> # Simplest use case: data comes in 1d and is stored as such
>>> data = TensorDict({
...     "obs": torch.zeros(10, 3),
...     "reward": torch.zeros(10, 1),
... }, batch_size=[10])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(data)
>>> # Samples 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)

Second, a buffer where we store trajectories. The maximum signal is aggregated over each batch (e.g., the reward of each rollout is summed):

>>> # One can also store batches of data, each batch being a sub-trajectory
>>> from torchrl.envs import GymEnv, ParallelEnv
>>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> # Get a batch of [2, 10] -- format is [Batch, Time]
>>> rollout = env.rollout(max_steps=10)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
>>> rb.extend(rollout)
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5, 10)

If the data comes in batches but a flat buffer is needed, we can simply flatten the data before extending the buffer:

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(rollout.reshape(-1))
>>> # Sample 5 transitions at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)

It is not possible to create a buffer that is extended along the time dimension, which is usually the recommended way of using buffers with batches of trajectories. Since the trajectories overlap, aggregating the reward values and comparing them is difficult if not impossible. This constructor is not valid (note the ndim argument):

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
add(data: Any) → int | torch.Tensor[source]

Inserts a single element of data at an appropriate index, and returns that index.

The rank_key in the data passed to this module should be structured as []. If it has more dimensions, it will be reduced to a single value using the reduction method.
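
A minimal sketch of this contract (an illustration, not part of the original docs, reusing the imports from the examples above); the buffer delegates to its writer, and the element's "reward" here is a scalar as required:

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=10),
...     writer=TensorDictMaxValueWriter(rank_key="reward"),
... )
>>> td = TensorDict({
...     "obs": torch.zeros(3),
...     "reward": torch.tensor(1.0),  # rank key constructed as [] (a scalar)
... }, batch_size=[])
>>> index = rb.add(td)  # returns the index the element was written at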

extend(data: TensorDictBase) → None[source]

Inserts a series of data points at appropriate indices.

The rank_key in the data passed to this module should be structured as [B]. If it has more dimensions, it will be reduced to a single value using the reduction method.
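
A sketch of the [B]-shaped contract (an illustration, not part of the original docs, reusing the imports above): each of the B=4 entries below is ranked by its reward reduced over the trailing [10, 1] dimensions:

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward"),
... )
>>> data = TensorDict({
...     "obs": torch.randn(4, 10, 3),
...     "reward": torch.randn(4, 10, 1),  # reduced (default: sum) to one score per entry
... }, batch_size=[4])
>>> rb.extend(data)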

get_insert_index(data: Any) → int[source]

Returns the index where the data should be inserted, or None if it should not be inserted.
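
A sketch of how this method could be exercised (an illustration, not from the original docs; accessing the buffer's writer through the private _writer attribute is an assumption made here for demonstration):

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=2),
...     writer=TensorDictMaxValueWriter(rank_key="reward"),
... )
>>> rb.extend(TensorDict({
...     "obs": torch.arange(3.0),
...     "reward": torch.tensor([1.0, 5.0, 3.0]),
... }, batch_size=[3]))  # buffer is full; the stored rewards are 5.0 and 3.0
>>> low = TensorDict({
...     "obs": torch.tensor(9.0),
...     "reward": torch.tensor(0.0),
... }, batch_size=[])
>>> print(rb._writer.get_insert_index(low))  # 0.0 doesn't beat the stored minimum
None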
