快捷方式

QValueModule

class torchrl.modules.tensordict_module.QValueModule(*args, **kwargs)[源]

用於 Q 值策略的 Q 值 TensorDictModule。

此模組根據給定的動作空間(one-hot、binary 或 categorical),將包含動作值的張量處理為其 argmax 分量(即結果的貪婪動作)。它適用於 tensordict 和普通張量。

引數:
  • action_space (str, 可選) – 動作空間。必須是 "one-hot", "mult-one-hot", "binary""categorical" 之一。此引數與 spec 互斥,因為 spec 決定了動作空間。

  • action_value_key (strtuple of str, 可選) – 表示動作值的輸入鍵。預設為 "action_value"

  • action_mask_key (strtuple of str, 可選) – 表示動作掩碼的輸入鍵。預設為 "None"(等同於無掩碼)。

  • out_keys (list of strtuple of str, 可選) – 表示動作、動作值和所選動作值的輸出鍵。預設為 ["action", "action_value", "chosen_action_value"]

  • var_nums (int, 可選) – 如果 action_space = "mult-one-hot",此值表示每個動作分量的基數。

  • spec (TensorSpec, 可選) – 如果提供,則表示動作(以及/或其它輸出)的規格。此引數與 action_space 互斥,因為 spec 決定了動作空間。

  • safe (bool) – 如果為 True,則檢查輸出值是否符合輸入規格。由於探索策略或數值下溢/溢位問題,可能會發生超出範圍的取樣。如果此值超出範圍,將使用 TensorSpec.project 方法將其投影回所需的空間。預設為 False

返回:

如果輸入是單個張量,則返回一個包含所選動作、值和所選動作值的三個元素組。如果提供了 tensordict,則使用 out_keys 欄位指示的鍵在其中更新這些條目。

示例

>>> from tensordict import TensorDict
>>> action_space = "categorical"
>>> action_value_key = "my_action_value"
>>> actor = QValueModule(action_space, action_value_key=action_value_key)
>>> # This module works with both tensordict and regular tensors:
>>> value = torch.zeros(4)
>>> value[-1] = 1
>>> actor(my_action_value=value)
(tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
>>> actor(value)
(tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
>>> actor(TensorDict({action_value_key: value}, []))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        my_action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
forward(tensordict: Tensor = None) TensorDictBase[源]

定義每次呼叫時執行的計算。

應被所有子類覆蓋。

注意

儘管前向傳播 (forward pass) 的實現需要在該函式內定義,但後續應呼叫 Module 例項而不是此函式本身,因為前者負責執行已註冊的鉤子,而後者會靜默忽略它們。


© 版權所有 2022, Meta。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源