快捷方式

SliceSampler

class torchrl.data.replay_buffers.SliceSampler(*, num_slices: int = None, slice_len: int = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None =('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | Tuple[bool | int, bool | int] = False, use_gpu: torch.device | bool = False)[source]

根據起始和停止訊號,沿著第一個維度取樣資料切片。

此類有放回地取樣子軌跡。對於無放回取樣版本,請參見 SliceSamplerWithoutReplacement

注意

SliceSampler 檢索軌跡索引可能很慢。為了加速其執行,優先使用 end_key 而非 traj_key,並考慮以下關鍵字引數:compilecache_valuesuse_gpu

關鍵字引數:
  • num_slices (int) – 要取樣的切片數量。批次大小必須大於或等於 num_slices 引數。與 slice_len 互斥。

  • slice_len (int) – 要取樣的切片長度。批次大小必須大於或等於 slice_len 引數,並且必須能被其整除。與 num_slices 互斥。

  • end_key (NestedKey, 可選) – 指示軌跡(或片段)結束的鍵。預設為 ("next", "done")

  • traj_key (NestedKey, 可選) – 指示軌跡的鍵。預設為 "episode"(在 TorchRL 的資料集中常用)。

  • ends (torch.Tensor, 可選) – 包含執行結束訊號的一維布林張量。在獲取 end_keytraj_key 成本較高或此訊號易於獲得時使用。必須與 cache_values=True 一起使用,且不能與 end_keytraj_key 同時使用。如果提供此引數,則假定儲存已滿,並且如果 ends 張量的最後一個元素為 False,則同一軌跡跨越結束和開始。

  • trajectories (torch.Tensor, 可選) – 包含執行 ID 的一維整數張量。在獲取 end_keytraj_key 成本較高或此訊號易於獲得時使用。必須與 cache_values=True 一起使用,且不能與 end_keytraj_key 同時使用。如果提供此引數,則假定儲存已滿,並且如果軌跡張量的最後一個元素與第一個元素相同,則同一軌跡跨越結束和開始。

  • cache_values (bool, 可選) –

    與靜態資料集一起使用。將快取軌跡的起始和結束訊號。即使在呼叫 extend 期間軌跡索引發生變化,此操作也會擦除快取,因此可以安全使用。

    警告

    cache_values=True 在取樣器與由另一個緩衝區擴充套件的儲存一起使用時將無法正常工作。例如:

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! Does not erase the buffer from the sampler of buffer0
    >>> buffer1.extend(data)
    

    警告

    如果緩衝區在程序之間共享,且一個程序負責寫入,一個程序負責取樣,則 cache_values=True 將無法按預期工作,因為擦除快取只能在本地完成。

  • truncated_key (NestedKey, 可選) – 如果不是 None,此引數指示應在輸出資料中寫入截斷訊號的位置。這用於向價值估計器指示提供的軌跡在哪裡中斷。預設為 ("next", "truncated")。此功能僅適用於 TensorDictReplayBuffer 例項(否則截斷鍵在 sample() 方法返回的資訊字典中返回)。

  • strict_length (bool, 可選) – 如果為 False,則允許長度短於 slice_len(或 batch_size // num_slices)的軌跡出現在批次中。如果為 True,則會過濾掉短於所需長度的軌跡。請注意,這可能導致實際 batch_size 小於請求的大小!可以使用 split_trajectories() 分割軌跡。預設為 True

  • compile (bool or dict of kwargs, 可選) – 如果為 True,則 sample() 方法的瓶頸將使用 compile() 進行編譯。關鍵字引數也可以透過此引數傳遞給 torch.compile。預設為 False

  • span (bool, int, Tuple[bool | int, bool | int], 可選) – 如果提供此引數,取樣的軌跡將跨越左側和/或右側。這意味著提供的元素可能少於所需數量。布林值表示每個軌跡至少會取樣一個元素。整數 i 表示每個取樣軌跡至少會收集 slice_len - i 個樣本。使用元組可以對左側(儲存軌跡的開頭)和右側(儲存軌跡的末尾)的跨越進行細粒度控制。

  • use_gpu (bool or torch.device) – 如果為 True(或傳遞了裝置),則將使用加速器來檢索軌跡起始的索引。當緩衝區內容較大時,這可以顯著加速取樣。預設為 False

注意

為了恢復儲存中的軌跡分割,SliceSampler 將首先嚐試在儲存中找到 traj_key 條目。如果找不到,將使用 end_key 重建片段。

注意

當使用 strict_length=False 時,建議使用 split_trajectories() 來分割取樣到的軌跡。然而,如果來自同一片段的兩個樣本相鄰放置,這可能會產生錯誤的結果。為了避免這個問題,請考慮以下解決方案之一:

  • TensorDictReplayBuffer 例項與切片取樣器一起使用

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.collectors.utils import split_trajectories
    >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
    >>>
    >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000),
    ...                   sampler=SliceSampler(
    ...                       slice_len=5, traj_key="episode",strict_length=False,
    ...                   ))
    ...
    >>> ep_1 = TensorDict(
    ...     {"obs": torch.arange(100),
    ...     "episode": torch.zeros(100),},
    ...     batch_size=[100]
    ... )
    >>> ep_2 = TensorDict(
    ...     {"obs": torch.arange(4),
    ...     "episode": torch.ones(4),},
    ...     batch_size=[4]
    ... )
    >>> rb.extend(ep_1)
    >>> rb.extend(ep_2)
    >>>
    >>> s = rb.sample(50)
    >>> print(s)
    TensorDict(
        fields={
            episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False),
            index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([46]),
                device=cpu,
                is_shared=False),
            obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([46]),
        device=cpu,
        is_shared=False)
    >>> t = split_trajectories(s, done_key="truncated")
    >>> print(t["obs"])
    tensor([[73, 74, 75, 76, 77],
            [ 0,  1,  2,  3,  0],
            [ 0,  1,  2,  3,  0],
            [41, 42, 43, 44, 45],
            [ 0,  1,  2,  3,  0],
            [67, 68, 69, 70, 71],
            [27, 28, 29, 30, 31],
            [80, 81, 82, 83, 84],
            [17, 18, 19, 20, 21],
            [ 0,  1,  2,  3,  0]])
    >>> print(t["episode"])
    tensor([[0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.],
            [1., 1., 1., 1., 0.],
            [0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.]])
    
  • 使用 SliceSamplerWithoutReplacement

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.collectors.utils import split_trajectories
    >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement
    >>>
    >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000),
    ...                   sampler=SliceSamplerWithoutReplacement(
    ...                       slice_len=5, traj_key="episode",strict_length=False
    ...                   ))
    ...
    >>> ep_1 = TensorDict(
    ...     {"obs": torch.arange(100),
    ...     "episode": torch.zeros(100),},
    ...     batch_size=[100]
    ... )
    >>> ep_2 = TensorDict(
    ...     {"obs": torch.arange(4),
    ...     "episode": torch.ones(4),},
    ...     batch_size=[4]
    ... )
    >>> rb.extend(ep_1)
    >>> rb.extend(ep_2)
    >>>
    >>> s = rb.sample(50)
    >>> t = split_trajectories(s, trajectory_key="episode")
    >>> print(t["obs"])
    tensor([[75, 76, 77, 78, 79],
            [ 0,  1,  2,  3,  0]])
    >>> print(t["episode"])
    tensor([[0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 0.]])
    

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(1_000_000),
...     sampler=SliceSampler(cache_values=True, num_slices=10),
...     batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
...     {
...         "episode": episode,
...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
...         "act": torch.randn((20,)).expand(1000, 20),
...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
...     }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)

SliceSampler 與大多數 TorchRL 資料集預設相容

示例

>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
...     batch = batch.reshape(num_slices, -1)
...     break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
        [14],
        [ 8],
        [10],
        [13],
        [ 4],
        [ 2],
        [ 3],
        [22],
        [ 8]])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的疑問

檢視資源