快捷方式

GAILLoss

torchrl.objectives.GAILLoss(*args, **kwargs)[源]

TorchRL 實現的生成對抗模仿學習 (GAIL) 損失函式。

發表於 “Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476>

引數

discriminator_network (TensorDictModule) – 隨機 actor

關鍵字引數
  • use_grad_penalty (bool, 可選) – 是否使用梯度懲罰。預設值:False

  • gp_lambda (float, 可選) – 梯度懲罰 lambda。預設值:10

  • reduction (str, 可選) – 指定要應用於輸出的歸約方式:"none" | "mean" | "sum""none": 不應用歸約,"mean": 輸出的總和將除以輸出中的元素數量,"sum": 對輸出求和。預設值:"mean"

default_keys

_AcceptedKeys 的別名

forward(tensordict: TensorDictBase = None) TensorDictBase[源]

forward 方法。

計算判別器損失,如果 use_grad_penalty 設定為 True,則計算梯度懲罰。如果 use_grad_penalty 設定為 True,還會返回分離的梯度懲罰損失,用於日誌記錄。要檢視輸入 tensordict 中期望的鍵和輸出中期望的鍵,請檢查類的 “in_keys”“out_keys” 屬性。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源