WrapModule¶
- class tensordict.nn.WrapModule(*args, **kwargs)¶
一個用於包裝處理 TensorDict 例項的任意可呼叫物件的包裝器。
當構建
TensorDictSequential棧時,以及當轉換需要整個 TensorDict 例項可見時,這個包裝器非常有用。- 引數:
func (Callable[[TensorDictBase], TensorDictBase]) – 一個可呼叫函式,接受一個 TensorDictBase 例項並返回一個轉換後的 TensorDictBase 例項。
- 關鍵字引數:
inplace (bool, optional) – 如果為
True,則原地修改輸入的 TensorDict。否則,將返回一個新的 TensorDict(如果函式未原地修改並返回它)。預設為False。in_keys (list of NestedKey, optional) – 如果提供,表示模組讀取哪些條目。這不會被檢查,僅用於通知
TensorDictSequential關於包裝模組的輸入鍵。預設為 []。out_keys (list of NestedKey, optional) – 如果提供,表示模組寫入哪些條目。這不會被檢查,僅用於通知
TensorDictSequential關於包裝模組的輸出鍵。預設為 []。
示例
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule >>> seq = Seq( ... Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]), ... WrapModule(lambda td: td.reshape(-1)), ... ) >>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4]) >>> td = Seq(td) >>> assert td.shape == (12,) >>> assert (td["y"] == 2).all() >>> assert td["y"].shape == (12, 5)
- forward(data: TensorDictBase) TensorDictBase¶
定義每次呼叫時執行的計算。
應由所有子類覆蓋。
注意
儘管正向傳播的實現需要在該函式內部定義,但後續應呼叫
Module例項而非此函式本身,因為前者負責執行註冊的鉤子,而後者則會默默忽略它們。