ConsistentDropoutModule¶
- class torchrl.modules.ConsistentDropoutModule(*args, **kwargs)[source]¶
ConsistentDropout 的 TensorDictModule 包裝器。
- 引數:
p (
float, 可選) – Dropout 機率。預設值:0.5。in_keys (NestedKey 或 NestedKey 列表) – 要從輸入 tensordict 讀取並傳遞給此模組的鍵。
out_keys (NestedKey 或 NestedKey 可迭代物件) – 要寫入輸入 tensordict 的鍵。預設為
in_keys的值。
- 關鍵字引數:
input_shape (
tuple, 可選) – 輸入(非批次)的形狀,用於使用make_tensordict_primer()生成 tensordict primers。input_dtype (torch.dtype, 可選) – primer 輸入的資料型別。如果未傳遞,則假定為
torch.get_default_dtype。
注意
要在策略中使用此類,需要在重置時重置掩碼。這可以透過使用
make_tensordict_primer()獲取的TensorDictPrimer變換來實現。有關更多資訊,請參閱此方法。示例
>>> from tensordict import TensorDict >>> module = ConsistentDropoutModule(p = 0.1) >>> td = TensorDict({"x": torch.randn(3, 4)}, [3]) >>> module(td) TensorDict( fields={ mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.bool, is_shared=False), x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- forward(tensordict)[source]¶
定義每次呼叫時執行的計算。
應由所有子類覆蓋。
注意
儘管前向傳播的實現需要在該函式中定義,但之後應該呼叫
Module例項而不是直接呼叫此函式,因為前者負責執行註冊的鉤子,而後者會默默忽略它們。
- make_tensordict_primer()[source]¶
為環境建立一個 tensordict primer,以便在重置呼叫期間生成隨機掩碼。
另請參閱
模組的所有 primer 的方法。
示例
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv >>> m = Seq( ... Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]), ... ConsistentDropoutModule( ... p=0.5, ... input_shape=(2, 4), ... in_keys="intermediate", ... ), ... Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]), ... ) >>> primer = get_primers_from_module(m) >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5)) >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6)) >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env]) >>> env = env.append_transform(primer) >>> r = env.rollout(10, m, break_when_any_done=False) >>> mask = [k for k in r.keys() if k.startswith("mask")][0] >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any() >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()