快捷方式

DuelingCnnDQNet

class torchrl.modules.DuelingCnnDQNet(out_features: int, out_features_value: int = 1, cnn_kwargs: dict | None = None, mlp_kwargs: dict | None = None, device: DEVICE_TYPING | None = None)[source]

Dueling CNN Q-網路。

介紹於 https://arxiv.org/abs/1511.06581

引數:
  • out_features (int) – 優勢網路 (advantage network) 的特徵數量。

  • out_features_value (int) – 值網路 (value network) 的特徵數量。

  • cnn_kwargs (dictdict 列表, 可選) –

    特徵網路的 kwargs。預設為

    >>> cnn_kwargs = {
    ...     'num_cells': [32, 64, 64],
    ...     'strides': [4, 2, 1],
    ...     'kernel_sizes': [8, 4, 3],
    ... }
    

  • mlp_kwargs (dictdict 列表, 可選) –

    優勢網路和值網路的 kwargs。預設為

    >>> mlp_kwargs = {
    ...     "depth": 1,
    ...     "activation_class": nn.ELU,
    ...     "num_cells": 512,
    ...     "bias_last_layer": True,
    ... }
    

  • device (torch.device, 可選) – 建立模組的裝置。

示例

>>> import torch
>>> from torchrl.modules import DuelingCnnDQNet
>>> net = DuelingCnnDQNet(out_features=20)
>>> print(net)
DuelingCnnDQNet(
  (features): ConvNet(
    (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ELU(alpha=1.0)
    (6): SquashDims()
  )
  (advantage): MLP(
    (0): LazyLinear(in_features=0, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=20, bias=True)
  )
  (value): MLP(
    (0): LazyLinear(in_features=0, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=1, bias=True)
  )
)
>>> x = torch.zeros(1, 3, 64, 64)
>>> y = net(x)
>>> print(y.shape)
torch.Size([1, 20])
forward(x: Tensor) Tensor[source]

定義每次呼叫時執行的計算。

應被所有子類覆蓋 (Override)。

注意

儘管正向傳播 (forward pass) 的實現需要在本函式中定義,但之後應該呼叫 Module 例項而不是本函式,因為前者會負責執行已註冊的鉤子 (hook),而後者則會靜默忽略它們。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源