ReplayBufferTrainer¶
- class torchrl.trainers.ReplayBufferTrainer(replay_buffer: TensorDictReplayBuffer, batch_size: Optional[int] = None, memmap: bool = False, device: DEVICE_TYPING | None = None, flatten_tensordicts: bool = False, max_dims: Optional[Sequence[int]] = None)[source]¶
回放緩衝區鉤子提供者。
- 引數:
replay_buffer (TensorDictReplayBuffer) – 要使用的回放緩衝區。
batch_size (int, optional) – 從最新採集或回放緩衝區中取樣資料時的批次大小。如果未提供,將使用回放緩衝區的批次大小(對於批次大小不變的情況,這是首選選項)。
memmap (bool, optional) – 如果為
True,則建立一個 memmap tensordict。預設為False。device (device, optional) – 必須放置樣本的裝置。預設為
None。flatten_tensordicts (bool, optional) – 如果為
True,tensordict 在傳遞給回放緩衝區之前將被展平(或等效地使用從收集器獲得的有效掩碼進行掩碼)。否則,除了填充(見下文max_dims引數)之外,不會執行其他變換。預設為False。max_dims (int 序列, optional) – 如果
flatten_tensordicts設定為 False,這將是一個列表,其長度等於提供的 tensordict 的批次大小,表示每個 tensordict 的最大尺寸。如果提供,此尺寸列表將用於填充 tensordict,並在將其傳遞給回放緩衝區之前使其形狀匹配。如果沒有最大值,應提供 -1。
示例
>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N) >>> trainer.register_op("batch_process", rb_trainer.extend) >>> trainer.register_op("process_optim_batch", rb_trainer.sample) >>> trainer.register_op("post_loss", rb_trainer.update_priority)