快捷方式

LSTMModule

class torchrl.modules.LSTMModule(*args, **kwargs)[source]

一個用於 LSTM 模組的嵌入器。

該類為 torch.nn.LSTM 添加了以下功能:

  • 與 TensorDict 的相容性:隱狀態被重塑以匹配 tensordict 的批次大小。

  • 可選的多步執行:使用 torch.nn 時,必須在 torch.nn.LSTMCelltorch.nn.LSTM 之間選擇,前者相容單步輸入,後者相容多步輸入。該類同時支援這兩種用法。

構建後,模組 設定為迴圈模式,即它將預期單步輸入。

如果在迴圈模式下,tensordict 的最後一個維度預計表示步數。tensordict 的維度沒有限制(除了對於時間輸入來說必須大於一)。

注意

該類可以處理沿時間維度的多個連續軌跡, 在這些情況下,不應信任最終的隱狀態值(即,它們不應被用於連續軌跡)。原因是 LSTM 只返回最後一個隱狀態值,對於我們提供的填充輸入,該值可能對應於一個填充零的輸入。

引數:
  • input_size – 輸入 x 中預期的特徵數量

  • hidden_size – 隱狀態 h 中的特徵數量

  • num_layers – 迴圈層的數量。例如,設定 num_layers=2 意味著將兩個 LSTM 堆疊在一起形成一個 堆疊 LSTM,第二個 LSTM 接收第一個 LSTM 的輸出並計算最終結果。預設值:1

  • bias – 如果為 False,則該層不使用偏置權重 b_ihb_hh。預設值:True

  • dropout – 如果非零,則在除最後一層以外的每個 LSTM 層的輸出上引入一個 Dropout 層,dropout 機率等於 dropout。預設值:0

  • python_based – 如果為 True,將使用 LSTM 單元的完整 Python 實現。預設值:False

關鍵字引數:
  • in_key (strstr 元組) – 模組的輸入鍵。與 in_keys 互斥。如果提供,迴圈鍵假定為 [“recurrent_state_h”, “recurrent_state_c”],in_key 將被附加在它們之前。

  • in_keys (str 列表) – 對應於輸入值、第一個和第二個隱狀態鍵的字串三元組。與 in_key 互斥。

  • out_key (strstr 元組) – 模組的輸出鍵。與 out_keys 互斥。如果提供,迴圈鍵假定為 [(“next”, “recurrent_state_h”), (“next”, “recurrent_state_c”)],out_key 將被附加在它們之前。

  • out_keys (str 列表) –

    對應於輸出值、第一個和第二個隱狀態鍵的字串三元組。.. 注意

    For a better integration with TorchRL's environments, the best naming
    for the output hidden key is ``("next", <custom_key>)``, such
    that the hidden values are passed from step to step during a rollout.
    

  • device (torch.device相容型別) – 模組的裝置。

  • lstm (torch.nn.LSTM, 可選) – 要封裝的 LSTM 例項。與其他 nn.LSTM 引數互斥。

  • default_recurrent_mode (bool, 可選) – 如果提供,則指定迴圈模式,除非已被 set_recurrent_mode 上下文管理器/裝飾器覆蓋。預設值為 False

變數:

recurrent_mode – 返回模組的迴圈模式。

set_recurrent_mode()[source]

控制模組是否應在迴圈模式下執行。

make_tensordict_primer()[source]

建立 TensorDictPrimer transforms,以便環境感知 RNN 的迴圈狀態。

注意

該模組依賴於輸入 TensorDict 中存在特定的 recurrent_state 鍵。要生成一個 TensorDictPrimer transform,該 transform 將自動把隱狀態新增到環境 TensorDict 中,請使用方法 make_tensordict_primer()。如果該類是一個更大模組的子模組,可以在父模組上呼叫方法 get_primers_from_module(),以自動生成包括本模組在內的所有子模組所需的 primer transforms。

示例

>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm_module = LSTMModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs_h", "rs_c"],
...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
forward(tensordict: TensorDictBase = None)[source]

定義每次呼叫時執行的計算。

應被所有子類覆蓋。

注意

儘管前向傳播的實現需要在該函式中定義,但之後應呼叫 Module 例項而不是該函式,因為前者負責執行已註冊的鉤子,而後者會靜默地忽略它們。

make_cudnn_based() LSTMModule[source]

將 LSTM 層轉換為基於 CuDNN 的版本。

返回值:

自身

make_python_based() LSTMModule[source]

將 LSTM 層轉換為基於 Python 的版本。

返回值:

自身

make_tensordict_primer()[source]

為環境建立一個 tensordict primer。

一個 TensorDictPrimer 物件將確保策略在 rollout 執行期間感知補充輸入和輸出(迴圈狀態)。這樣,資料可以在不同程序之間共享並得到妥善處理。

當使用批處理環境(例如 ParallelEnv)時,transform 可以在單個環境例項級別使用(即,一批內部設定了 tensordict primer 的轉換環境),也可以在批處理環境例項級別使用(即,一個轉換過的常規環境批次)。

不在環境中包含 TensorDictPrimer 可能導致行為定義不清,例如在並行設定中,一步涉及將新的迴圈狀態從 "next" 複製到根 tensordict,而 meth:~torchrl.EnvBase.step_mdp 方法將無法完成此操作,因為迴圈狀態未在環境規範中註冊。

參閱 torchrl.modules.utils.get_primers_from_module(),瞭解生成給定模組所有 primer 的方法。

示例

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm_module = LSTMModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs_h", "rs_c"],
...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(lstm_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=10
... )
>>> for data in data_collector:
...     print(data)
...     break
set_recurrent_mode(mode: bool = True)[source]

[已棄用 - 請改用 torchrl.modules.set_recurrent_mode 上下文管理器] 返回模組的新副本,該副本共享相同的 lstm 模型,但具有不同的 recurrent_mode 屬性(如果不同)。

建立副本是為了使模組可以在程式碼的不同部分(推斷 vs 訓練)以不同的行為方式使用。

示例

>>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> # building two policies with different behaviors:
>>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> # let's check that both return the same results
>>> td_inf = TensorDict(batch_size=traj_td.shape[:-1])
>>> for td in traj_td.unbind(-1):
...     td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation")))
...     td_inf = policy_inference(td_inf)
...     td_inf = step_mdp(td_inf)
...
>>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"])

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源