GAILLoss¶
- 類 torchrl.objectives.GAILLoss(*args, **kwargs)[源]¶
TorchRL 實現的生成對抗模仿學習 (GAIL) 損失函式。
發表於 “Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476>
- 引數:
discriminator_network (TensorDictModule) – 隨機 actor
- 關鍵字引數:
use_grad_penalty (bool, 可選) – 是否使用梯度懲罰。預設值:
False。gp_lambda (
float, 可選) – 梯度懲罰 lambda。預設值:10。reduction (str, 可選) – 指定要應用於輸出的歸約方式:
"none"|"mean"|"sum"。"none": 不應用歸約,"mean": 輸出的總和將除以輸出中的元素數量,"sum": 對輸出求和。預設值:"mean"。
- default_keys¶
_AcceptedKeys 的別名