ConsistentDropout¶
- class torchrl.modules.ConsistentDropout(p: float = 0.5)[source]¶
實現具有一致性 dropout 的
Dropout變體。該方法在 “Consistent Dropout for Policy Gradient Reinforcement Learning” (Hausknecht & Wagener, 2022) 中提出。
該
Dropout變體透過快取 rollout 期間使用的 dropout 掩碼並在更新階段重用它們,試圖增加訓練穩定性並減少更新方差。您正在檢視的此類獨立於 TorchRL API 的其餘部分,並且不需要 tensordict 即可執行。
ConsistentDropoutModule是ConsistentDropout的一個包裝器,它利用TensorDict的可擴充套件性,將生成的 dropout 掩碼儲存在 transitionTensorDict本身中。有關詳細說明和使用示例,請參閱此類。除此之外,與 PyTorch
Dropout實現相比,概念上沒有太大偏差。- ..注意:: TorchRL 的資料收集器在
no_grad()模式下執行 rollout,但不在 eval 模式下執行, 因此除非傳遞給收集器的策略處於 eval 模式,否則 dropout 掩碼將被應用。
注意
與其他探索模組不同,
ConsistentDropoutModule使用train/eval模式以符合 PyTorch 中常規的 Dropout API。set_exploration_type()上下文管理器對此模組無效。- 引數:
p (
float, 可選) – Dropout 機率。預設為0.5。
另請參閱
MultiSyncDataCollector: 內部使用_main_async_collector()(SyncDataCollector)
- forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[source]¶
在訓練期間 (rollout & 更新),此呼叫在與輸入張量相乘之前,會遮蔽一個全為 1 的張量。
在評估期間,此呼叫結果為無操作,僅返回輸入。
- 引數:
x (torch.Tensor) – 輸入張量。
mask (torch.Tensor, 可選) – 用於 dropout 的可選掩碼。
返回:在訓練模式下返回一個張量和對應的掩碼,在評估模式下僅返回一個張量。
- ..注意:: TorchRL 的資料收集器在