TensorDictMaxValueWriter
- class torchrl.data.replay_buffers.TensorDictMaxValueWriter(rank_key=None, reduction: str = 'sum', **kwargs)[source]
A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
- Parameters:
rank_key (str or tuple of str) – the key used to rank the elements. Defaults to ("next", "reward").
reduction (str) – the reduction method to use if the rank key has more than one element. Can be "max", "min", "mean", "median" or "sum".
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
...     "key": torch.tensor(range(10, 20)),
...     "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
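The following sketch (an illustrative addition, not part of the original example) shows the reduction argument in action: when the rank key holds more than one entry per element, it is collapsed to a single number before ranking. With reduction="max", the per-element maximum over "key" decides which element is kept, so under the same buffer setup the element with obs == 9 should be the one retained:

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key", reduction="max"),
... )
>>> td = TensorDict({
...     "key": torch.stack([torch.arange(10.), -torch.arange(10.)], -1),  # shape [10, 2]
...     "obs": torch.arange(10)
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())  # max over "key" is largest for obs == 9
9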
Note
This class is not compatible with storages that have more than one dimension. This does not mean that storing trajectories is prohibited, but rather that stored trajectories must be written on a per-trajectory basis. Here are some examples of valid and invalid usages of the class. First, a flat buffer in which we store individual transitions:
>>> from torchrl.data import TensorStorage
>>> # Simplest use case: data comes in 1d and is stored as such
>>> data = TensorDict({
...     "obs": torch.zeros(10, 3),
...     "reward": torch.zeros(10, 1),
... }, batch_size=[10])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(data)
>>> # Samples 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
Second, a buffer in which we store trajectories. The max signal is aggregated within each batch (e.g., the reward of each rollout is summed):
>>> from torchrl.envs import ParallelEnv, GymEnv
>>> # One can also store batches of data, each batch being a sub-trajectory
>>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> # Get a batch of [2, 10] -- format is [Batch, Time]
>>> rollout = env.rollout(max_steps=10)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
>>> rb.extend(rollout)
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5, 10)
If the data comes in batches but a flat buffer is needed, we can simply flatten the data before extending the buffer:
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(rollout.reshape(-1))
>>> # Sample 5 transitions at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
It is not possible to create a buffer that is extended along the time dimension, which is otherwise the usual recommended way of storing trajectories in a buffer. Since trajectories would overlap, aggregating their reward values for comparison is difficult, if not impossible. This constructor is not valid (note the ndim argument):
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
- add(data: Any) → int | torch.Tensor[source]
Inserts a single element of data at an appropriate index, and returns that index.
The rank_key in the data passed to this module should be structured as [] (i.e., a scalar). If it has more dimensions, it will be reduced to a single value using the reduction method.
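As a minimal sketch of the shape requirement above (an illustrative addition, not taken from the library's documentation), a single transition can be written through the buffer's add(); here the "reward" entry has shape [1] rather than [], so it is reduced (summed, the default) to one scalar before ranking:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward"),
... )
>>> # A single transition with batch_size []; "reward" has shape [1] and is
>>> # reduced to a single ranking value before insertion
>>> transition = TensorDict({
...     "obs": torch.zeros(3),
...     "reward": torch.ones(1),
... }, batch_size=[])
>>> index = rb.add(transition)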