• 文件 >
  • 資料收集和儲存入門
快捷方式

資料收集和儲存入門

作者: Vincent Moens

注意

要在 notebook 中執行本教程,請在開頭新增一個安裝單元格,其中包含

!pip install tensordict
!pip install torchrl
import tempfile

沒有資料就沒有學習。在監督學習中,使用者習慣於使用 DataLoader 等工具將資料整合到訓練迴圈中。Dataloader 是可迭代物件,它們為你提供用於訓練模型的資料。

TorchRL 以類似的方式處理資料載入問題,儘管這在強化學習庫生態系統中是出乎意料的獨特之處。TorchRL 的資料載入器被稱為 DataCollectors。大多數時候,資料收集並不僅僅是原始資料的收集,因為在被損失模組消耗之前,資料需要暫時儲存在緩衝區(或用於 on-policy 演算法的等效結構)中。本教程將探討這兩個類。

資料收集器

此處討論的主要資料收集器是 SyncDataCollector,這是本文件的重點。在基礎層面,收集器是一個簡單的類,負責在環境中執行你的策略,必要時重置環境,並提供預定義大小的批次資料。與環境教程中演示的 rollout() 方法不同,收集器在連續的資料批次之間不會重置。因此,連續的兩個資料批次可能包含來自同一軌跡的元素。

你需要傳遞給收集器的基本引數是你想收集的批次大小(frames_per_batch),迭代器的長度(可能無限),策略和環境。為簡單起見,在此示例中我們將使用一個虛擬的隨機策略。

import torch

torch.manual_seed(0)

from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy

env = GymEnv("CartPole-v1")
env.set_seed(0)

policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)

現在我們期望收集器無論在收集過程中發生什麼,都將交付大小為 200 的批次資料。換句話說,這個批次中可能包含多個軌跡!total_frames 指示收集器應該執行多長時間。值為 -1 將生成一個永不停止的收集器。

讓我們迭代收集器,瞭解這些資料看起來像什麼

for data in collector:
    print(data)
    break

如你所見,我們的資料增加了一些收集器特有的元資料,這些元資料被分組到一個 "collector"tensordict 中,這是我們在環境 rollout 期間沒有看到的。這對於跟蹤軌跡 ID 很有用。在下面的列表中,每個條目標記了相應 transition 所屬的軌跡編號

print(data["collector", "traj_ids"])

在編寫最先進的演算法時,資料收集器非常有用,因為效能通常是透過特定技術在給定數量的環境互動次數(收集器中的 total_frames 引數)內解決問題的能力來衡量的。因此,我們示例中的大多數訓練迴圈都像這樣

..code - block::Python

>>> for data in collector:
...     # your algorithm here

回放緩衝區

既然我們已經探索瞭如何收集資料,我們想知道如何儲存它。在強化學習中,典型的設定是收集資料,臨時儲存,並在一段時間後根據某種啟發式方法清除:先入先出或其他。一個典型的虛擬碼看起來像這樣

..code - block::Python

>>> for data in collector:
...     storage.store(data)
...     for i in range(n_optim):
...         sample = storage.sample()
...         loss_val = loss_fn(sample)
...         loss_val.backward()
...         optim.step() # etc

TorchRL 中儲存資料的父類被稱為 ReplayBuffer。TorchRL 的回放緩衝區是可組合的:你可以編輯儲存型別、它們的取樣技術、寫入啟發式方法或應用於它們的 transforms。我們將把更高階的內容留給專門的深入教程。通用的回放緩衝區只需要知道它要使用什麼儲存。通常,我們推薦使用 TensorStorage 子類,這在大多數情況下都能很好地工作。在本教程中,我們將使用 LazyMemmapStorage,它具有兩個很好的特性:首先,它很“lazy”,你無需提前明確告知它你的資料是什麼樣子。其次,它使用 MemoryMappedTensor 作為後端,以高效的方式將你的資料儲存在磁碟上。你唯一需要知道的是你希望緩衝區有多大。

from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer

buffer_scratch_dir = tempfile.TemporaryDirectory().name

buffer = ReplayBuffer(
    storage=LazyMemmapStorage(max_size=1000, scratch_dir=buffer_scratch_dir)
)

填充緩衝區可以透過 add()(單個元素)或 extend()(多個元素)方法完成。使用我們剛剛收集的資料,我們可以一步初始化並填充緩衝區

indices = buffer.extend(data)

我們可以檢查緩衝區現在具有與我們從收集器中獲得的資料相同的元素數量

assert len(buffer) == collector.frames_per_batch

唯一需要知道的是如何從緩衝區中獲取資料。自然地,這依賴於 sample() 方法。由於我們沒有指定取樣必須是不重複的,因此不能保證從緩衝區獲取的樣本是唯一的

sample = buffer.sample(batch_size=30)
print(sample)

再次,我們的樣本看起來與我們從收集器收集的資料完全相同!

下一步

  • 你可以檢視其他多程序收集器,例如 MultiSyncDataCollectorMultiaSyncDataCollector

  • 如果你有多個節點用於推理,TorchRL 還提供了分散式收集器。請在API 參考中檢視它們。

  • 查閱專門的回放緩衝區教程,瞭解構建緩衝區時的更多選項,或查閱API 參考,其中涵蓋所有詳細功能。回放緩衝區具有無數功能,例如多執行緒取樣、優先經驗回放等等……

  • 為簡單起見,我們沒有介紹回放緩衝區的迭代能力。你可以自己嘗試一下:構建一個緩衝區並在建構函式中指定其 batch-size,然後嘗試對其進行迭代。這相當於在迴圈中呼叫 rb.sample()

由 Sphinx-Gallery 生成的相簿

文件

獲取 PyTorch 全面的開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源