快捷方式

DistributionalDQNnet

class torchrl.modules.DistributionalDQNnet(*args, **kwargs)[source]

分佈深度 Q 網路 Softmax 層。

此層應位於預測動作值的常規模型與作用於 logits 值的分佈之間使用。

引數:
  • in_keys (str 列表或 str 元組) – log-softmax 操作的輸入鍵。預設為 ["action_value"]

  • out_keys (str 列表或 str 元組) – log-softmax 操作的輸出鍵。預設為 ["action_value"]

示例

>>> import torch
>>> from tensordict import TensorDict
>>> net = DistributionalDQNnet()
>>> td = TensorDict({"action_value": torch.randn(10, 5)}, batch_size=[10])
>>> net(td)
TensorDict(
    fields={
        action_value: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
forward(tensordict=None)[source]

定義每次呼叫時執行的計算。

應被所有子類覆蓋。

注意

儘管前向傳播的實現需要在該函式中定義,但之後應該呼叫 Module 例項而不是該函式,因為前者負責執行註冊的鉤子,而後者會默默地忽略它們。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源