快捷方式

ConsistentDropout

class torchrl.modules.ConsistentDropout(p: float = 0.5)[source]

實現具有一致性 dropout 的 Dropout 變體。

該方法在 “Consistent Dropout for Policy Gradient Reinforcement Learning” (Hausknecht & Wagener, 2022) 中提出。

Dropout 變體透過快取 rollout 期間使用的 dropout 掩碼並在更新階段重用它們,試圖增加訓練穩定性並減少更新方差。

您正在檢視的此類獨立於 TorchRL API 的其餘部分,並且不需要 tensordict 即可執行。ConsistentDropoutModuleConsistentDropout 的一個包裝器,它利用 TensorDict 的可擴充套件性,將生成的 dropout 掩碼儲存在 transition TensorDict 本身中。有關詳細說明和使用示例,請參閱此類。

除此之外,與 PyTorch Dropout 實現相比,概念上沒有太大偏差。

..注意:: TorchRL 的資料收集器在 no_grad() 模式下執行 rollout,但不在 eval 模式下執行,

因此除非傳遞給收集器的策略處於 eval 模式,否則 dropout 掩碼將被應用。

注意

與其他探索模組不同,ConsistentDropoutModule 使用 train/eval 模式以符合 PyTorch 中常規的 Dropout API。set_exploration_type() 上下文管理器對此模組無效。

引數:

p (float, 可選) – Dropout 機率。預設為 0.5

另請參閱

forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[source]

在訓練期間 (rollout & 更新),此呼叫在與輸入張量相乘之前,會遮蔽一個全為 1 的張量。

在評估期間,此呼叫結果為無操作,僅返回輸入。

引數:

返回:在訓練模式下返回一個張量和對應的掩碼,在評估模式下僅返回一個張量。

文件

查閱 PyTorch 全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源