快捷方式

EGreedyModule

class torchrl.modules.EGreedyModule(*args, **kwargs)[原始碼]

Epsilon-Greedy 探索模組。

該模組根據 epsilon-greedy 探索策略隨機更新 tensordict 中的動作。每次呼叫時,會根據某個機率閾值執行隨機抽取(每個動作一次)。如果抽取成功,對應的動作將被替換為從提供的動作規範 (action spec) 中抽取的隨機樣本。未被抽取的動作將保持不變。

引數:
  • spec (TensorSpec) – 用於取樣動作的規範。

  • eps_init (標量, 可選) – 初始 epsilon 值。預設為 1.0

  • eps_end (標量, 可選) – 最終 epsilon 值。預設為 0.1

  • annealing_num_steps (int, 可選) – epsilon 達到 eps_end 值所需的步數。預設為 1000

關鍵字引數:
  • action_key (NestedKey, 可選) – 輸入 tensordict 中動作所在的鍵。預設為 "action"

  • action_mask_key (NestedKey, 可選) – 輸入 tensordict 中動作掩碼所在的鍵。預設為 None(表示沒有掩碼)。

  • device (torch.device, 可選) – 探索模組所在的裝置。

注意

至關重要的是,要在訓練迴圈中呼叫 step() 來更新探索因子。由於很難捕獲這種遺漏,因此如果遺漏了此步驟,不會引發警告或異常!

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential
>>> from torchrl.modules import EGreedyModule, Actor
>>> from torchrl.data import Bounded
>>> torch.manual_seed(0)
>>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = TensorDictSequential(policy,  EGreedyModule(eps_init=0.2))
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
forward(tensordict: TensorDictBase) TensorDictBase[原始碼]

定義每次呼叫時執行的計算。

應由所有子類覆蓋。

注意

雖然前向傳播的實現需要在該函式中定義,但之後應呼叫 Module 例項而不是直接呼叫此函式,因為前者負責執行已註冊的鉤子,而後者則會靜默忽略它們。

step(frames: int = 1) None[原始碼]

一次 epsilon 衰減。

在此方法被呼叫 self.annealing_num_steps 次後,後續呼叫將無效。

引數:

frames (int, 可選) – 自上次步進以來的幀數。預設為 1

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並解答疑問

檢視資源