LazyTensorStorage¶
- class torchrl.data.replay_buffers.LazyTensorStorage(max_size: int, *, device: device = 'cpu', ndim: int = 1, compilable: bool = False, consolidated: bool = False)[source]¶
一個用於張量(tensor)和張量字典(tensordict)的預分配張量儲存。
- 引數:
max_size (int) – 儲存的大小,即緩衝區中儲存的最大元素數量。
- 關鍵字引數:
device (torch.device, optional) – 取樣張量儲存和傳送到的裝置。預設為
torch.device("cpu")。如果傳入“auto”,裝置將自動從傳入的第一批資料中獲取。預設不啟用此功能,以避免錯誤地將資料放置在 GPU 上,從而導致 OOM(記憶體不足)問題。ndim (int, optional) – 測量儲存大小時需要考慮的維度數量。例如,形狀為
[3, 4]的儲存,如果ndim=1,其容量為3;如果ndim=2,其容量為12。預設為1。compilable (bool, optional) – 儲存是否可編譯。如果為
True,寫入器不能在多個程序之間共享。預設為False。consolidated (bool, optional) – 如果為
True,儲存將在首次擴充套件後被整合。預設為False。
示例
>>> data = TensorDict({ ... "some data": torch.randn(10, 11), ... ("some", "nested", "data"): torch.randn(10, 11, 12), ... }, batch_size=[10, 11]) >>> storage = LazyTensorStorage(100) >>> storage.set(range(10), data) >>> len(storage) # only the first dimension is considered as indexable 10 >>> storage.get(0) TensorDict( fields={ some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), some: TensorDict( fields={ nested: TensorDict( fields={ data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([11]), device=cpu, is_shared=False)}, batch_size=torch.Size([11]), device=cpu, is_shared=False)}, batch_size=torch.Size([11]), device=cpu, is_shared=False) >>> storage.set(0, storage.get(0).zero_()) # zeros the data along index ``0``
此類也支援 tensorclass 資料。
示例
>>> from tensordict import tensorclass >>> @tensorclass ... class MyClass: ... foo: torch.Tensor ... bar: torch.Tensor >>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11]) >>> storage = LazyTensorStorage(10) >>> storage.set(range(10), data) >>> storage.get(0) MyClass( bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False), foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), batch_size=torch.Size([11]), device=cpu, is_shared=False)
- attach(buffer: Any) None¶
此函式將取樣器附加到此儲存。
從該儲存讀取的緩衝區必須透過呼叫此方法作為附加實體包含進來。這確保了當儲存中的資料發生變化時,即使儲存與其他緩衝區(例如,優先順序取樣器)共享,元件也能感知到變化。
- 引數:
buffer – 從此儲存讀取的物件。
- dump(*args, **kwargs)¶
dumps()的別名。
- load(*args, **kwargs)¶
loads()的別名。
- save(*args, **kwargs)¶
dumps()的別名。