快捷方式

TensorDictSequential

class tensordict.nn.TensorDictSequential(*args, **kwargs)

TensorDictModules 的序列。

類似於 nn.Sequence,它透過一個鏈式對映傳遞張量,每個對映讀取並寫入一個張量,此模組將透過查詢每個輸入模組來讀寫 tensordict。當使用函式式模組呼叫 TensorDictSequencial 例項時,引數列表(和緩衝區)預計會被連線到單個列表中。

引數:

模組 (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]) – 有序的可呼叫物件序列,它們以 TensorDictBase 作為輸入並返回 TensorDictBase。這些可以是 TensorDictModuleBase 的例項,或任何符合此簽名的其他函式。請注意,如果使用了非 TensorDictModuleBase 的可呼叫物件,其輸入和輸出鍵將不會被跟蹤,因此不會影響 TensorDictSequential 的 in_keysout_keys 屬性。常規的 dict 輸入如有必要將被轉換為 OrderedDict

關鍵字引數:
  • partial_tolerant (bool, 可選) – 如果為 True,輸入的 tensordict 可以缺少某些輸入鍵。在這種情況下,將只執行那些根據現有鍵可以執行的模組。此外,如果輸入的 tensordict 是 tensordict 的惰性堆疊,並且 partial_tolerant 為 True,並且堆疊中缺少必需的鍵,那麼 TensorDictSequential 將掃描子 tensordict,查詢是否存在具有必需鍵的子 tensordict。預設為 False。

  • selected_out_keys (巢狀鍵的可迭代物件, 可選) – 要選擇的輸出鍵列表。如果未提供,將寫入所有 out_keys

注意

一個 TensorDictSequential 例項可能有很多輸出鍵,出於清晰性或記憶體目的,可能希望在執行後移除其中一些鍵。如果出現這種情況,可以在例項化後使用方法 select_out_keys(),或者將 selected_out_keys 傳遞給建構函式。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> torch.manual_seed(0)
>>> module = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["x+1"]),
...     TensorDictModule(nn.Linear(3, 4), in_keys=["x+1"], out_keys=["w*(x+1)+b"]),
... )
>>> # with tensordict input
>>> print(module(TensorDict({"x": torch.zeros(3)}, [])))
TensorDict(
    fields={
        w*(x+1)+b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        x+1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # with tensor input: returns all the output keys in the order of the modules, ie "x+1" and "w*(x+1)+b"
>>> module(x=torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748,  0.1571, -0.1138], grad_fn=<AddBackward0>))
>>> module(torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748,  0.1571, -0.1138], grad_fn=<AddBackward0>))

TensorDictSequence 支援函式式、模組化和 vmap 程式設計。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
...     TensorDictSequential,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch.distributions import Normal
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> net1 = torch.nn.Linear(4, 8)
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["params"])
>>> normal_params = TensorDictModule(
...      NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
...  )
>>> td_module1 = ProbabilisticTensorDictSequential(
...     module1,
...     normal_params,
...     ProbabilisticTensorDictModule(
...         in_keys=["loc", "scale"],
...         out_keys=["hidden"],
...         distribution_class=Normal,
...         return_log_prob=True,
...     )
... )
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
...    module=module2, in_keys=["hidden"], out_keys=["output"]
... )
>>> td_module = TensorDictSequential(td_module1, td_module2)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
在 vmap 的情況下
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) TensorDictBase

如果未設定 tensordict 引數,則使用 kwargs 建立 TensorDict 例項。

reset_out_keys()

out_keys 屬性重置為其原始值。

返回值:同一個模組,其 out_keys 值恢復為原始值。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
select_out_keys(*selected_out_keys) TensorDictSequential

選擇將在輸出 tensordict 中找到的鍵。

這在想要移除複雜圖中的中間鍵,或者當這些鍵的存在可能引發意外行為時很有用。

原始的 out_keys 仍然可以透過 module.out_keys_source 訪問。

引數:

*out_keys (字串序列字串元組) – 應在輸出 tensordict 中找到的輸出鍵。

返回值:同一個模組,已就地修改並更新了 out_keys

最簡單的用法是結合 TensorDictModule 使用。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

此功能也適用於分派的引數: 示例

>>> mod(torch.zeros(()), torch.ones(()))
tensor(2.)

此更改將就地發生 (即返回同一個模組,但 out_keys 列表已更新)。可以使用 TensorDictModuleBase.reset_out_keys() 方法恢復此更改。

示例

>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

這也適用於其他類,例如 Sequential: 示例

>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
select_subsequence(in_keys: Optional[Iterable[NestedKey]] = None, out_keys: Optional[Iterable[NestedKey]] = None) TensorDictSequential

返回一個新的 TensorDictSequential,其中僅包含計算給定輸入鍵的給定輸出鍵所需的模組。

引數:
  • in_keys – 我們要選擇的子序列的輸入鍵。所有不在 in_keys 中的鍵將被視為不相關,並且 *僅* 將這些鍵作為輸入的模組將被丟棄。生成的 sequential 模組將遵循模式“所有模組的輸出會因任何在 中的鍵的不同值而受到影響”。如果未提供,則假定使用模組的 in_keys

  • out_keys – 我們要選擇的子序列的輸出鍵。生成的序列中將只包含獲取 out_keys 所必需的模組。生成的 sequential 模組將遵循模式“所有對 條目的值構成條件的模組。”如果未提供,則假定使用模組的 out_keys

返回值:

一個新的 TensorDictSequential,其中僅包含根據給定的輸入和輸出鍵所需的模組。

示例

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> idn = lambda x: x
>>> module = Seq(
...     Mod(idn, in_keys=["a"], out_keys=["b"]),
...     Mod(idn, in_keys=["b"], out_keys=["c"]),
...     Mod(idn, in_keys=["c"], out_keys=["d"]),
...     Mod(idn, in_keys=["a"], out_keys=["e"]),
... )
>>> # select all modules whose output depend on "a"
>>> module.select_subsequence(in_keys=["a"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
      (2): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
      (3): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c', 'd', 'e'])
>>> # select all modules whose output depend on "c"
>>> module.select_subsequence(in_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
    ),
    device=cpu,
    in_keys=['c'],
    out_keys=['d'])
>>> # select all modules that affect the value of "c"
>>> module.select_subsequence(out_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c'])
>>> # select all modules that affect the value of "e"
>>> module.select_subsequence(out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['e'])

此方法會傳播到巢狀的 sequential

>>> module = Seq(
...     Seq(
...         Mod(idn, in_keys=["a"], out_keys=["b"]),
...         Mod(idn, in_keys=["b"], out_keys=["c"]),
...     ),
...     Seq(
...         Mod(idn, in_keys=["b"], out_keys=["d"]),
...         Mod(idn, in_keys=["d"], out_keys=["e"]),
...     ),
... )
>>> # select submodules whose output will be affected by a change in "b" or "d" AND which output is "e"
>>> module.select_subsequence(in_keys=["b", "d"], out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictSequential(
          module=ModuleList(
            (0): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['b'],
                out_keys=['d'])
            (1): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['d'],
                out_keys=['e'])
          ),
          device=cpu,
          in_keys=['b'],
          out_keys=['d', 'e'])
    ),
    device=cpu,
    in_keys=['b'],
    out_keys=['d', 'e'])

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源