GRUModule¶
- class torchrl.modules.GRUModule(*args, **kwargs)[source]¶
GRU 模組的嵌入器。
此類別為
torch.nn.GRU添加了以下功能:與 TensorDict 的相容性:隱藏狀態會被重塑以匹配 tensordict 的批次大小。
可選的多步執行:使用 torch.nn 時,必須在
torch.nn.GRUCell和torch.nn.GRU之間進行選擇,前者與單步輸入相容,後者與多步相容。此類別同時支援這兩種用法。
構建後,模組預設**不**處於迴圈模式,即它會期望單步輸入。
如果在迴圈模式下,tensordict 的最後一個維度預期表示步數。tensordict 的維度沒有限制(除了對於時間序列輸入,維度必須大於一)。
- 引數:
input_size – 輸入 x 中期望的特徵數量
hidden_size – 隱藏狀態 h 中的特徵數量
num_layers – 迴圈層數量。例如,設定
num_layers=2意味著將兩個 GRU 堆疊在一起形成一個 堆疊 GRU,其中第二個 GRU 接收第一個 GRU 的輸出並計算最終結果。預設值:1bias – 如果為
False,則層不使用偏置權重。預設值:Truedropout – 如果非零,則在除最後一層外的每個 GRU 層的輸出上引入一個 Dropout 層,丟棄機率等於
dropout。預設值:0python_based – 如果為
True,將使用完整的 Python 實現的 GRU Cell。預設值:False
- 關鍵字引數:
in_key (str 或 tuple of str) – 模組的輸入鍵。與
in_keys互斥使用。如果提供,迴圈鍵假定為 [“recurrent_state”],並且in_key將在此之前新增。in_keys (list of str) – 對應於輸入值和迴圈條目的字串對。與
in_key互斥。out_key (str 或 tuple of str) – 模組的輸出鍵。與
out_keys互斥使用。如果提供,迴圈鍵假定為 [(“recurrent_state”)],並且out_key將在這些鍵之前新增。out_keys (list of str) –
對應於輸出值和隱藏狀態的字串對。
For a better integration with TorchRL's environments, the best naming for the output hidden key is ``("next", <custom_key>)``, such that the hidden values are passed from step to step during a rollout.device (torch.device 或 相容型別) – 模組的裝置。
gru (torch.nn.GRU, 可選) – 要包裝的 GRU 例項。與其他 nn.GRU 引數互斥。
default_recurrent_mode (bool, 可選) – 如果提供,則為未被
set_recurrent_mode上下文管理器/裝飾器覆蓋時的迴圈模式。預設為False。
- 變數:
recurrent_mode – 返回模組的迴圈模式。
注意
此模組依賴於輸入 TensorDict 中存在特定的
recurrent_state鍵。要生成一個會自動向環境 TensorDict 新增隱藏狀態的TensorDictPrimer變換,請使用方法make_tensordict_primer()。如果此類別是較大模組的子模組,則可以在父模組上呼叫方法get_primers_from_module(),以自動生成所有子模組(包括此模組)所需的 primer 變換。示例
>>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> gru_module = GRUModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs"], ... out_keys=["intermediate", ("next", "rs")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) TensorDict( fields={ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False) >>> gru_module_training = gru_module.set_recurrent_mode() >>> policy_training = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> traj_td = env.rollout(3) # some random temporal data >>> traj_td = policy_training(traj_td) >>> print(traj_td) TensorDict( fields={ action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False), is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3]), device=cpu, is_shared=False)
- forward(tensordict: TensorDictBase = None)[source]¶
定義每次呼叫時執行的計算。
應由所有子類覆蓋。
注意
雖然前向傳播的實現需要在該函式中定義,但之後應該呼叫
Module例項而不是此函式本身,因為前者會處理已註冊的鉤子,而後者會靜默忽略它們。
- make_tensordict_primer()[source]¶
為環境建立 tensordict primer。
一個
TensorDictPrimer物件將確保策略在 rollout 執行期間感知到補充的輸入和輸出(迴圈狀態)。這樣,資料可以在程序之間共享並得到妥善處理。如果在環境中不包含
TensorDictPrimer,可能會導致行為定義不清,例如在並行設定中,一步操作涉及將新的迴圈狀態從"next"複製到根 tensordict,而 meth:~torchrl.EnvBase.step_mdp 方法將無法完成此操作,因為迴圈狀態未在環境規範中註冊。使用批次環境(例如
ParallelEnv)時,變換可以在單個環境例項級別(即內部設定了 tensordict primer 的批次變換環境)或批次環境例項級別(即常規環境的變換批次)使用。有關生成給定模組所有 primer 的方法,請參見
torchrl.modules.utils.get_primers_from_module()。示例
>>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP, LSTMModule >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> gru_module = GRUModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs"], ... out_keys=["intermediate", ("next", "rs")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) >>> env = env.append_transform(gru_module.make_tensordict_primer()) >>> data_collector = SyncDataCollector( ... env, ... policy, ... frames_per_batch=10 ... ) >>> for data in data_collector: ... print(data) ... break
- set_recurrent_mode(mode: bool = True)[source]¶
[已棄用 - 請改用
torchrl.modules.set_recurrent_mode上下文管理器] 返回模組的新副本,該副本共享相同的 gru 模型,但具有不同的recurrent_mode屬性(如果不同)。建立副本是為了使模組可以在程式碼的不同部分(推理與訓練)中以不同的行為方式使用。
示例
>>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp >>> from torchrl.modules import MLP >>> from tensordict import TensorDict >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) >>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> # building two policies with different behaviors: >>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> traj_td = env.rollout(3) # some random temporal data >>> traj_td = policy_training(traj_td) >>> # let's check that both return the same results >>> td_inf = TensorDict(batch_size=traj_td.shape[:-1]) >>> for td in traj_td.unbind(-1): ... td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation"))) ... td_inf = policy_inference(td_inf) ... td_inf = step_mdp(td_inf) ... >>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"])