快捷方式

SliceSamplerWithoutReplacement

class torchrl.data.replay_buffers.SliceSamplerWithoutReplacement(*, num_slices: int | None = None, slice_len: int | None = None, drop_last: bool = False, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, shuffle: bool = True, compile: bool | dict = False, use_gpu: bool | torch.device = False)[原始碼]

在給定開始和停止訊號的情況下,沿第一維無放回地取樣資料切片。

在此上下文中,無放回 意味著在計數器自動重置之前,同一個元素(不是軌跡)不會被重複取樣。然而,在單個取樣中,給定軌跡只會出現一個切片(見下面的示例)。

此類應與靜態回放緩衝區或在兩次回放緩衝區擴充套件之間使用。擴充套件回放緩衝區將重置取樣器,目前不允許連續無放回取樣。

注意

SliceSamplerWithoutReplacement 在檢索軌跡索引時可能很慢。為了加速其執行,優先使用 end_key 而不是 traj_key,並考慮以下關鍵字引數:compile, cache_valuesuse_gpu

關鍵字引數
  • drop_last (bool, 可選) – 如果為 True,則最後一個不完整的樣本(如果存在)將被丟棄。如果為 False,則會保留最後一個樣本。預設為 False

  • num_slices (int) – 要取樣的切片數量。批次大小必須大於或等於 num_slices 引數。與 slice_len 互斥。

  • slice_len (int) – 要取樣的切片長度。批次大小必須大於或等於 slice_len 引數且能被其整除。與 num_slices 互斥。

  • end_key (NestedKey, 可選) – 指示軌跡(或情節)結束的鍵。預設為 ("next", "done")

  • traj_key (NestedKey, 可選) – 指示軌跡的鍵。預設為 "episode"(在 TorchRL 的資料集中常用)。

  • ends (torch.Tensor, 可選) – 一個 1 維布林張量,包含執行結束訊號。當 end_keytraj_key 獲取成本很高,或者此訊號已準備好時使用。必須與 cache_values=True 一起使用,不能與 end_keytraj_key 結合使用。

  • trajectories (torch.Tensor, 可選) – 一個 1 維整數張量,包含執行 ID。當 end_keytraj_key 獲取成本很高,或者此訊號已準備好時使用。必須與 cache_values=True 一起使用,不能與 end_keytraj_key 結合使用。

  • truncated_key (NestedKey, 可選) – 如果不為 None,此引數指示應將截斷訊號寫入輸出資料的哪個位置。這用於向值估計器指示提供的軌跡中斷的位置。預設為 ("next", "truncated")。此功能僅適用於 TensorDictReplayBuffer 例項(否則截斷鍵將在 sample() 方法返回的 info 字典中返回)。

  • strict_length (bool, 可選) – 如果為 False,則允許長度小於 slice_len(或 batch_size // num_slices)的軌跡出現在批次中。如果為 True,則將過濾掉長度小於要求的軌跡。請注意,這可能導致實際 batch_size 小於請求的大小!可以使用 split_trajectories() 分割軌跡。預設為 True

  • shuffle (bool, 可選) – 如果為 False,則軌跡的順序不會被打亂。預設為 True

  • compile (booldict of kwargs, 可選) – 如果為 True,則 sample() 方法的瓶頸部分將使用 compile() 進行編譯。也可以使用此引數將關鍵字引數傳遞給 torch.compile。預設為 False

  • use_gpu (booltorch.device) – 如果為 True(或傳遞了裝置),則將使用加速器來檢索軌跡起始點的索引。當緩衝區內容很大時,這可以顯著加速取樣。預設為 False

注意

為了恢復儲存中的軌跡分割,SliceSamplerWithoutReplacement 將首先嚐試在儲存中查詢 traj_key 條目。如果找不到,將使用 end_key 重建情節。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
>>>
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1000),
...     # asking for 10 slices for a total of 320 elements, ie, 10 trajectories of 32 transitions each
...     sampler=SliceSamplerWithoutReplacement(num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> # since we want trajectories of 32 transitions but there are only 4 episodes to
>>> # sample from, we only get 4 x 32 = 128 transitions in this batch
>>> print("sample:", sample)
>>> print("trajectories in sample", sample.get("episode").unique())

SliceSamplerWithoutReplacement 與大多數 TorchRL 資料集預設相容,並允許使用者以類似資料載入器的方式使用資料集

示例

>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSamplerWithoutReplacement
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320,
...     sampler=SliceSamplerWithoutReplacement(num_slices=num_slices))
>>> # the last sample is kept, since drop_last=False by default
>>> for i, batch in enumerate(data):
...     print(batch.get("episode").unique())
tensor([ 5,  6,  8, 11, 12, 14, 16, 17, 19, 24])
tensor([ 1,  2,  7,  9, 10, 13, 15, 18, 21, 22])
tensor([ 0,  3,  4, 20, 23])

當請求大量總樣本但軌跡數量少且跨度小時,批次中每個軌跡最多隻包含一個樣本

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.collectors.utils import split_trajectories
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
>>>
>>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
...                   sampler=SliceSamplerWithoutReplacement(
...                       slice_len=5, traj_key="episode",strict_length=False
...                   ))
...
>>> ep_1 = TensorDict(
...     {"obs": torch.arange(100),
...     "episode": torch.zeros(100),},
...     batch_size=[100]
... )
>>> ep_2 = TensorDict(
...     {"obs": torch.arange(51),
...     "episode": torch.ones(51),},
...     batch_size=[51]
... )
>>> rb.extend(ep_1)
>>> rb.extend(ep_2)
>>>
>>> s = rb.sample(50)
>>> t = split_trajectories(s, trajectory_key="episode")
>>> print(t["obs"])
tensor([[14, 15, 16, 17, 18],
        [ 3,  4,  5,  6,  7]])
>>> print(t["episode"])
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.]])
>>>
>>> s = rb.sample(50)
>>> t = split_trajectories(s, trajectory_key="episode")
>>> print(t["obs"])
tensor([[ 4,  5,  6,  7,  8],
        [26, 27, 28, 29, 30]])
>>> print(t["episode"])
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.]])

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源