TensorDictModule¶
- class tensordict.nn.TensorDictModule(*args, **kwargs)¶
TensorDictModule 是一個 Python 包裝器,它包裝了
nn.Module,用於從 TensorDict 中讀取和寫入資料。- 引數:
module (Callable) – 一個可呼叫物件,通常是
torch.nn.Module,用於將輸入對映到輸出引數空間。其 forward 方法可以返回單個張量、張量元組甚至字典。在後一種情況下,TensorDictModule的輸出鍵將用於填充輸出 tensordict(即out_keys中存在的鍵應存在於moduleforward 方法返回的字典中)。in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – 要從輸入 tensordict 中讀取並傳遞給 module 的鍵。如果它包含多個元素,這些值將按照 in_keys 可迭代物件給出的順序傳遞。如果
in_keys是一個字典,它的鍵必須對應於 tensordict 中要讀取的鍵,而它的值必須匹配函式簽名中的關鍵字引數名稱。如果 out_to_in_map 為True,則對映被反轉,使得鍵對應於函式簽名中的關鍵字引數。out_keys (iterable of str) – 要寫入輸入 tensordict 的鍵。out_keys 的長度必須與嵌入模組返回的張量數量匹配。使用 “_” 作為鍵可以避免將張量寫入輸出。
- 關鍵字引數:
out_to_in_map (bool, optional) –
如果為
True,則 in_keys 被讀取時,其鍵被視為forward()方法的引數鍵,值則是輸入TensorDict中的鍵。如果為False或None(預設),則鍵被視為輸入鍵,值被視為方法的引數鍵。警告
out_to_in_map 的預設值將在 v0.9 版本中從
False更改為True。inplace (bool or string, optional) –
如果為
True(預設),模組的輸出將寫入提供給forward()方法的 tensordict 中。如果為False,則會建立一個新的TensorDict例項,其批大小為空且沒有裝置。如果為"empty",將使用empty()建立輸出 tensordict。注意
如果
inplace=False並且傳遞給模組的 tensordict 是TensorDict以外的TensorDictBase子類,輸出仍將是TensorDict例項。它的批大小將為空,並且沒有裝置。設定為"empty"可以獲得相同的TensorDictBase子型別、相同的批大小和裝置。在執行時使用tensordict_out(見下文)可以對輸出進行更細粒度的控制。注意
如果
inplace=False並且 tensordict_out 被傳遞給forward()方法,則tensordict_out將優先。這是獲取傳遞給模組的 tensordict 是TensorDictBase子類(而不是TensorDict)的 tensordict_out 的方法,輸出仍將是TensorDict例項。
在 TensorDictModule 中嵌入神經網路只需要指定輸入和輸出鍵。TensorDictModule 支援函式式和常規的
nn.Module物件。在函式式情況下,必須指定 'params'(和 'buffers')關鍵字引數示例
>>> from tensordict import TensorDict >>> # one can wrap regular nn.Module >>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"]) >>> input = torch.ones(2, 3, 128) >>> tgt = torch.zeros(2, 3, 128) >>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3]) >>> data = module(data) >>> print(data) TensorDict( fields={ input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False), out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False), tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2, 3]), device=None, is_shared=False)
我們也可以直接傳遞張量
示例
>>> out = module(input, tgt) >>> assert out.shape == input.shape >>> # we can also wrap regular functions >>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")]) >>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[])) TensorDict( fields={ input: TensorDict( fields={ x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), output: TensorDict( fields={ x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
我們可以使用 TensorDictModule 來填充 tensordict
示例
>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"]) >>> print(module(TensorDict({}, batch_size=[]))) TensorDict( fields={ x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
另一個特性是傳遞字典作為輸入鍵,以控制值到特定關鍵字引數的分派。
示例
>>> module = TensorDictModule(lambda x, *, y: x+y, ... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False ... ) >>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, [])) >>> td['z'] tensor(3.)
如果 out_to_in_map 設定為
True,則 in_keys 對映將被反轉。這樣,同一個輸入鍵可以用於不同的關鍵字引數。示例
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z, ... in_keys={'x': '1', 'y': '2', z: '2'}, out_keys=['t'], out_to_in_map=True ... ) >>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, [])) >>> td['t'] tensor(5.)
tensordict 模組的函式式呼叫很簡單
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> module = torch.nn.GRUCell(4, 8) >>> td_module = TensorDictModule( ... module=module, in_keys=["input", "hidden"], out_keys=["output"] ... ) >>> params = TensorDict.from_module(td_module) >>> # functional API >>> with params.to_module(td_module): ... td_functional = td_module(td.clone()) >>> print(td_functional) TensorDict( fields={ hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- 在有狀態情況下
>>> module = torch.nn.GRUCell(4, 8) >>> td_module = TensorDictModule( ... module=module, in_keys=["input", "hidden"], out_keys=["output"] ... ) >>> td_stateful = td_module(td.clone()) >>> print(td_stateful) TensorDict( fields={ hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)
- forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) TensorDictBase¶
當未設定 tensordict 引數時,kwargs 用於建立 TensorDict 例項。