• 文件 >
  • 使用回放緩衝區
快捷方式

使用回放緩衝區

作者: Vincent Moens

回放緩衝區是任何 RL 或控制演算法的核心組成部分。監督學習方法通常透過一個訓練迴圈來表徵,其中資料從靜態資料集中隨機提取並依次饋送到模型和損失函式。在 RL 中,情況通常略有不同:資料使用模型收集,然後臨時儲存在動態結構(經驗回放緩衝區)中,該結構用作損失模組的資料集。

一如既往,緩衝區的用途極大地影響了它的構建方式:有些人可能希望儲存軌跡,而另一些人則希望儲存單個轉換。特定的取樣策略可能在某些上下文中更受歡迎:某些專案可能比其他專案具有更高的優先順序,或者有放回或無放回抽樣可能很重要。計算因素也可能發揮作用,例如緩衝區的大小可能超出可用的 RAM 儲存空間。

由於這些原因,TorchRL 的回放緩衝區是完全可組合的:雖然它們自帶“電池”,只需最少的努力即可構建,但它們也支援許多自定義,例如儲存型別、取樣策略或資料 transforms。

在本教程中,你將學習

基礎知識:構建一個普通回放緩衝區

TorchRL 的回放緩衝區旨在優先考慮模組化、可組合性、效率和簡單性。例如,建立一個基本的回放緩衝區是一個簡單的過程,如下例所示

import tempfile

from torchrl.data import ReplayBuffer

buffer = ReplayBuffer()

預設情況下,此回放緩衝區的大小為 1000。我們透過使用 extend() 方法填充緩衝區來檢查這一點

print("length before adding elements:", len(buffer))

buffer.extend(range(2000))

print("length after adding elements:", len(buffer))

我們使用了旨在一次新增多個專案的 extend() 方法。如果傳遞給 extend 的物件具有多個維度,則其第一個維度將被視為要在緩衝區中拆分成單獨元素的部分。

這實質上意味著,當將多維 tensors 或 tensordicts 新增到緩衝區時,緩衝區在計算其記憶體中儲存的元素時,只會檢視第一個維度。如果傳遞的物件不可迭代,則會丟擲異常。

要逐個新增專案,應改用 add() 方法。

自定義儲存

我們看到緩衝區已被限制為我們傳遞給它的前 1000 個元素。要更改大小,我們需要自定義我們的儲存。

TorchRL 提供三種類型的儲存

  • ListStorage 將元素獨立儲存在列表中。它支援任何資料型別,但這種靈活性是以效率為代價的;

  • LazyTensorStorage 連續儲存 tensors 資料結構。它自然地與 TensorDict(或 tensorclass)物件一起工作。儲存是按 tensor 連續的,這意味著取樣將比使用列表時更高效,但隱含的限制是傳遞給它的任何資料必須與用於例項化緩衝區的第一個資料批次具有相同基本屬性(如 shape 和 dtype)。傳遞不符合此要求的資料將引發異常或導致某些未定義的行為。

  • LazyMemmapStorage 的工作方式與 LazyTensorStorage 類似,它也是 lazy 的(即,它期望例項化第一個資料批次),並且它要求儲存的每個批次資料在 shape 和 dtype 上匹配。使這種儲存獨特之處在於它指向磁碟檔案(或使用檔案系統儲存),這意味著它可以支援非常大的資料集,同時仍以連續的方式訪問資料。

讓我們看看如何使用這些儲存中的每一種

from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ListStorage

# We define the maximum size of the buffer
size = 100

帶有列表儲存的緩衝區可以儲存任何型別的資料(但我們必須更改 collate_fn,因為預設期望數值資料)

buffer_list = ReplayBuffer(storage=ListStorage(size), collate_fn=lambda x: x)
buffer_list.extend(["a", 0, "b"])
print(buffer_list.sample(3))

由於它具有最少量的假設,ListStorage 是 TorchRL 中的預設儲存。

LazyTensorStorage 可以連續儲存資料。在處理複雜但不變的中等大小資料結構時,這應該是首選選項

buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size))

讓我們建立一個大小為 torch.Size([3]) 的資料批次,其中儲存了 2 個 tensors

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "a": torch.arange(12).view(3, 4),
        ("b", "c"): torch.arange(15).view(3, 5),
    },
    batch_size=[3],
)
print(data)

第一次呼叫 extend() 將例項化儲存。資料的第一維度被解開為單獨的資料點

buffer_lazytensor.extend(data)
print(f"The buffer has {len(buffer_lazytensor)} elements")

讓我們從緩衝區中取樣,並列印資料

sample = buffer_lazytensor.sample(5)
print("samples", sample["a"], sample["b", "c"])

LazyMemmapStorage 也以同樣的方式建立。我們還可以自定義磁碟上的儲存位置

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = ReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir)
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    print(
        "the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename
    )
    print(
        "the ('b', 'c') tensor is stored in",
        buffer_lazymemmap._storage._storage["b", "c"].filename,
    )
    sample = buffer_lazytensor.sample(5)
    print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"])
    del buffer_lazymemmap

與 TensorDict 整合

tensor 位置遵循包含它們的 TensorDict 的相同結構:這使得在訓練期間輕鬆儲存和載入緩衝區成為可能。

為了充分發揮 TensorDict 作為資料載體的潛力,可以使用 TensorDictReplayBuffer 類。它的一個主要優點是能夠處理取樣資料的組織,以及可能需要的任何附加資訊(例如樣本索引)。

它可以像標準的 ReplayBuffer 一樣構建,並且通常可以互換使用。

from torchrl.data import TensorDictReplayBuffer

with tempfile.TemporaryDirectory() as tempdir:
    buffer_lazymemmap = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
    )
    buffer_lazymemmap.extend(data)
    print(f"The buffer has {len(buffer_lazymemmap)} elements")
    sample = buffer_lazymemmap.sample()
    print("sample:", sample)
    del buffer_lazymemmap

我們的樣本現在有一個額外的 "index" 鍵,指示取樣了哪些索引。讓我們看看這些索引

print(sample["index"])

與 tensorclass 整合

ReplayBuffer 類及其相關的子類也原生支援 tensorclass 類,這些類可以方便地用於以更明確的方式編碼資料集

from tensordict import tensorclass


@tensorclass
class MyData:
    images: torch.Tensor
    labels: torch.Tensor


data = MyData(
    images=torch.randint(
        255,
        (10, 64, 64, 3),
    ),
    labels=torch.randint(100, (10,)),
    batch_size=[10],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=12)
buffer_lazy.extend(data)
print(f"The buffer has {len(buffer_lazy)} elements")
sample = buffer_lazy.sample()
print("sample:", sample)

正如所料,資料具有正確的類和 shape!

與其他 tensor 結構(PyTrees)整合

TorchRL 的回放緩衝區也支援任何 pytree 資料結構。一個 PyTree 是由 dicts、lists 和/或 tuples 組成的任意深度的巢狀結構,其中葉子是 tensors。這意味著可以在連續記憶體中儲存任何此類樹結構!可以使用各種儲存:TensorStorageLazyMemmapStorageLazyTensorStorage 都接受這種資料。

這裡是對此功能的簡要演示

from torch.utils._pytree import tree_map

讓我們在 RAM 上構建回放緩衝區

rb = ReplayBuffer(storage=LazyTensorStorage(size))
data = {
    "a": torch.randn(3),
    "b": {"c": (torch.zeros(2), [torch.ones(1)])},
    30: -torch.ones(()),  # non-string keys also work
}
rb.add(data)

# The sample has a similar structure to the data (with a leading dimension of 10 for each tensor)
sample = rb.sample(10)

對於 pytrees,任何可呼叫物件都可以用作 transform

def transform(x):
    # Zeros all the data in the pytree
    return tree_map(lambda y: y * 0, x)


rb.append_transform(transform)
sample = rb.sample(batch_size=12)

讓我們檢查一下我們的 transform 是否正常工作

def assert0(x):
    assert (x == 0).all()


tree_map(assert0, sample)

取樣和遍歷緩衝區

回放緩衝區支援多種取樣策略

  • 如果 batch-size 是固定的並且可以在構建時定義,則可以將其作為關鍵字引數傳遞給緩衝區;

  • 使用固定的 batch-size,可以遍歷回放緩衝區以收集樣本;

  • 如果 batch-size 是動態的,則可以在執行時將其傳遞給 sample 方法。

可以使用多執行緒進行取樣,但這與最後一種選擇不相容(因為它要求緩衝區預先知道下一個批次的大小)。

讓我們看幾個例子

固定 batch-size

如果在構建期間傳遞了 batch-size,則在取樣時應省略它

data = MyData(
    images=torch.randint(
        255,
        (200, 64, 64, 3),
    ),
    labels=torch.randint(100, (200,)),
    batch_size=[200],
)

buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=128)
buffer_lazy.extend(data)
buffer_lazy.sample()

此資料批次的大小是我們想要的大小 (128)。

要啟用多執行緒取樣,只需在構建期間將正整數傳遞給 prefetch 關鍵字引數。這應顯著加快取樣速度,尤其是在取樣耗時的情況下(例如,使用優先採樣器時)

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=128, prefetch=10
)  # creates a queue of 10 elements to be prefetched in the background
buffer_lazy.extend(data)
print(buffer_lazy.sample())

以固定 batch-size 遍歷緩衝區

只要 batch-size 是預定義的,我們也可以像使用常規 dataloader 一樣遍歷緩衝區

for i, data in enumerate(buffer_lazy):
    if i == 3:
        print(data)
        break

del buffer_lazy

由於我們的取樣技術是完全隨機的並且不阻止有放回取樣,因此該迭代器是無限的。但是,我們可以改用 SamplerWithoutReplacement(無放回取樣器),它將把我們的緩衝區轉換為一個有限迭代器

from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), batch_size=32, sampler=SamplerWithoutReplacement()
)

我們建立一個足夠大的資料來獲取幾個樣本

data = TensorDict(
    {
        "a": torch.arange(64).view(16, 4),
        ("b", "c"): torch.arange(128).view(16, 8),
    },
    batch_size=[16],
)

buffer_lazy.extend(data)
for _i, _ in enumerate(buffer_lazy):
    continue
print(f"A total of {_i+1} batches have been collected")

del buffer_lazy

動態 batch-size

與我們之前看到的不同,batch_size 關鍵字引數可以省略,並直接傳遞給 sample 方法

buffer_lazy = ReplayBuffer(
    storage=LazyTensorStorage(size), sampler=SamplerWithoutReplacement()
)
buffer_lazy.extend(data)
print("sampling 3 elements:", buffer_lazy.sample(3))
print("sampling 5 elements:", buffer_lazy.sample(5))

del buffer_lazy

優先回放緩衝區

TorchRL 還提供了 優先回放緩衝區 的介面。此類緩衝區根據透過資料傳遞的優先順序訊號進行取樣。

雖然此工具相容非 TensorDict 資料,但我們鼓勵改用 TensorDict,因為它使得在緩衝區內外攜帶元資料變得容易。

讓我們首先看看如何在一般情況下構建一個優先回放緩衝區。\(\alpha\) 和 \(\beta\) 超引數必須手動設定

from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100

rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
    collate_fn=lambda x: x,
)

擴充套件回放緩衝區會返回專案索引,我們稍後將需要這些索引來更新優先順序

indices = rb.extend([1, "foo", None])

取樣器期望每個元素都有一個優先順序。當新增到緩衝區時,優先順序被設定為預設值 1。優先順序計算後(通常透過損失函式),必須在緩衝區中更新它。

這是透過 update_priority() 方法完成的,該方法需要索引和優先順序。我們將資料集中的第二個樣本分配一個人為的高優先順序,以觀察其對取樣的影響

rb.update_priority(index=indices, priority=torch.tensor([0, 1_000, 0.1]))

我們觀察到從緩衝區取樣的結果主要是第二個樣本("foo"

sample, info = rb.sample(10, return_info=True)
print(sample)

info 包含專案的相對權重以及索引。

print(info)

我們看到,使用優先回放緩衝區與使用常規緩衝區相比,在訓練迴圈中需要一系列額外的步驟

  • 收集資料並擴充套件緩衝區後,必須更新專案的優先順序;

  • 計算損失並從中獲取“優先順序訊號”後,我們必須再次更新緩衝區中專案的優先順序。這需要我們跟蹤索引。

這極大地阻礙了緩衝區的可重用性:如果要編寫一個訓練指令碼,其中既可以建立優先緩衝區也可以建立常規緩衝區,則她必須新增大量的控制流,以確保僅當使用優先緩衝區時,才在適當的位置呼叫適當的方法。

讓我們看看如何使用 TensorDict 來改進這一點。我們看到 TensorDictReplayBuffer 返回的資料透過其相對儲存索引進行了增強。我們沒有提到的一個特性是,如果優先順序訊號在擴充套件期間存在,此類還會確保將其自動解析到優先採樣器。

這些功能的結合在幾個方面簡化了事情:- 擴充套件緩衝區時,優先順序訊號將自動

如果存在則被解析,並且優先順序將被準確分配;

  • 索引將儲存在取樣的 tensordicts 中,使得在損失計算後易於更新優先順序。

  • 計算損失時,優先順序訊號將註冊到傳遞給損失模組的 tensordict 中,使得無需努力即可更新權重

    ..code - block::Python

    >>> data = replay_buffer.sample()
    >>> loss_val = loss_module(data)
    >>> replay_buffer.update_tensordict_priority(data)
    

以下程式碼闡述了這些概念。我們構建了一個帶有優先採樣器的回放緩衝區,並在建構函式中指明瞭應該獲取優先順序訊號的入口

rb = TensorDictReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(size, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=1024,
)

讓我們選擇一個與儲存索引成比例的優先順序訊號

data["td_error"] = torch.arange(data.numel())

rb.extend(data)

sample = rb.sample()

較高的索引應該更頻繁地出現

from matplotlib import pyplot as plt

fig = plt.hist(sample["index"].numpy())
plt.show()

處理完樣本後,我們使用 torchrl.data.TensorDictReplayBuffer.update_tensordict_priority() 方法更新優先順序鍵。為了演示其工作原理,我們反轉取樣專案的優先順序

sample = rb.sample()
sample["td_error"] = data.numel() - sample["index"]
rb.update_tensordict_priority(sample)

現在,較高的索引應該更少地出現

sample = rb.sample()

fig = plt.hist(sample["index"].numpy())
plt.show()

使用 transforms

儲存在回放緩衝區中的資料可能尚未準備好呈現給損失模組。在某些情況下,collector 生成的資料可能太重而無法按原樣儲存。例如,將影像從 uint8 轉換為浮點 tensors,或在使用 decision transformers 時連線連續的幀。

只需向緩衝區附加適當的 transform,即可處理進出緩衝區的資料。這裡有一些例子

儲存原始影像

uint8 型別的 tensors 比我們通常饋送到模型的浮點 tensors 在記憶體方面便宜得多。因此,儲存原始影像會很有用。以下指令碼展示瞭如何構建一個僅返回原始影像但使用 transformed 影像進行推理的 collector,以及如何將這些 transformations 在回放緩衝區中重複使用

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    Compose,
    GrayScale,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import RandomPolicy

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
    ),
)

讓我們看看一個 rollout

print(env.rollout(3))

我們剛剛建立了一個生成畫素的環境。這些影像經過處理以饋送給策略。我們希望儲存原始影像,而不是它們的 transforms。為此,我們將一個 transform 附加到 collector,以選擇我們希望出現的鍵

from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
    postproc=ExcludeTransform("pixels_trsf", ("next", "pixels_trsf"), "collector"),
)

讓我們看看一批資料,並確認 "pixels_trsf" 鍵已被丟棄

for data in collector:
    print(data)
    break

collector.shutdown()

我們建立一個回放緩衝區,其 transform 與環境相同。然而,有一個細節需要注意:在沒有環境的情況下使用的 transforms 對資料結構一無所知。將 transform 附加到環境時,巢狀 tensordict 中 "next" 的資料首先被轉換,然後在 rollout 執行期間複製到根目錄。使用靜態資料時,情況並非如此。儘管如此,我們的資料帶有一個巢狀的 “next” tensordict,如果我們不明確指示 transform 處理它,它將被忽略。我們手動將這些鍵新增到 transform 中

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=t, batch_size=16)
rb.extend(data)

我們可以檢查一個 sample 方法是否能看到 transformed 影像重新出現

print(rb.sample())

一個更復雜的示例:使用 CatFrames

CatFrames transform 隨時間展開 observations,建立一個過去事件的 n-back 記憶體,使模型能夠考慮過去事件(在 POMDPs 或使用像 Decision Transformers 這樣的迴圈策略的情況下)。儲存這些連線的幀會消耗大量的記憶體。當訓練和推理期間 n-back 視窗需要不同(通常更長)時,這也可能成為問題。我們透過在兩個階段中分別執行 CatFrames transform 來解決這個問題。

from torchrl.envs import CatFrames, UnsqueezeTransform

我們為返回畫素observations 的環境建立了一個標準的 transforms 列表

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True),
    Compose(
        ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]),
        Resize(in_keys=["pixels_trsf"], w=64, h=64),
        GrayScale(in_keys=["pixels_trsf"]),
        UnsqueezeTransform(-4, in_keys=["pixels_trsf"]),
        CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]),
    ),
)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=10,
    total_frames=1000,
)
for data in collector:
    print(data)
    break

collector.shutdown()

緩衝區 transform 看起來與環境的 transform 非常相似,但像之前一樣帶有額外的 ("next", ...)

t = Compose(
    ToTensorImage(
        in_keys=["pixels", ("next", "pixels")],
        out_keys=["pixels_trsf", ("next", "pixels_trsf")],
    ),
    Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64),
    GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
    CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(size), transform=t, batch_size=16)
data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf"))
rb.add(data_exclude)

讓我們從緩衝區中取樣一個批次。transformed 畫素鍵的 shape 在倒數第四個維度上應該長度為 4

s = rb.sample(1)  # the buffer has only one element
print(s)

經過一些處理(排除未使用鍵等)後,我們看到線上和離線生成的資料匹配!

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

儲存軌跡

在許多情況下,最好從緩衝區訪問軌跡而不是簡單的轉換。TorchRL 提供了多種實現方式。

目前,首選方法是將軌跡沿著緩衝區的第一維儲存,並使用 SliceSampler 對這些資料批次進行取樣。此類只需要一些關於你的資料結構的資訊即可完成其工作(注意,截至目前,它僅與 tensordict 結構的資料相容):slices 的數量或其長度,以及關於 episode 之間在哪裡找到分離的資訊(例如,回想一下,使用 DataCollector 時,軌跡 ID 儲存在 ("collector", "traj_ids") 中)。在這個簡單的例子中,我們構建了一個包含 4 個連續短軌跡的資料,並從中取樣了 4 個 slices,每個 slices 的長度為 2(因為 batch size 是 8,並且 8 items // 4 slices = 2 time steps)。我們還標記了 steps。

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

gc.collect()

結論

我們已經瞭解瞭如何在 TorchRL 中使用回放緩衝區,從最簡單的用法到需要轉換或以特定方式儲存資料的更高階用法。現在你應該能夠

  • 建立回放緩衝區,自定義其儲存、取樣器和 transforms;

  • 為你的問題選擇最佳儲存型別(list、記憶體或基於磁碟的);

  • 最小化緩衝區的記憶體佔用。

下一步

  • 查閱資料 API 參考,瞭解 TorchRL 中的離線資料集,這些資料集基於我們的回放緩衝區 API;

  • 查閱其他取樣器,例如 SamplerWithoutReplacementPrioritizedSliceSamplerSliceSamplerWithoutReplacement,或檢視其他 writer,例如 TensorDictMaxValueWriter

  • 查閱 文件,瞭解如何檢查點 ReplayBuffers。

由 Sphinx-Gallery 生成的相簿

文件

查閱 PyTorch 全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源