快捷方式

TensorDictModuleBase

class tensordict.nn.TensorDictModuleBase(*args, **kwargs)

TensorDict 模組的基類。

TensorDictModule 的子類透過 in_keysout_keys 鍵列表來標識,這些列表指示應讀取哪些輸入條目以及應寫入哪些輸出條目。

forward 方法的輸入/輸出簽名應始終遵循以下約定

>>> tensordict_out = module.forward(tensordict_in)

TensorDictModule 不同,TensorDictModuleBase 通常透過子類化使用:只要子類 forward 方法讀寫 tensordict(或相關型別)例項,您就可以將任何 Python 函式包裝到 TensorDictModuleBase 子類中。

應正確指定 in_keysout_keys。例如,可以使用 select_out_keys() 動態減少 out_keys

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModuleBase
>>> class Mod(TensorDictModuleBase):
...     in_keys = ["a"] # can also be specified during __init__
...     out_keys = ["b", "c"]
...     def forward(self, tensordict):
...         b = tensordict["a"].clone()
...         c = b + 1
...         return tensordict.replace({"b": b, "c": c})
>>> mod = Mod()
>>> td = mod(TensorDict(a=0))
>>> td["b"]
tensor(0)
>>> td["c"]
tensor(1)
>>> mod.select_out_keys("c")
>>> td = mod(TensorDict(a=0))
>>> td["c"]
tensor(1)
>>> assert "b" not in td
static is_tdmodule_compatible(module)

檢查模組是否與 TensorDictModule API 相容。

reset_out_keys()

out_keys 屬性重置為其原始值。

返回值:具有其原始 out_keys 值的同一模組。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
reset_parameters_recursive(parameters: Optional[TensorDictBase] = None) Optional[TensorDictBase]

遞迴重置模組及其子模組的引數。

引數:

parameters (引數的 TensorDict,可選) – 如果設定為 None,模組將使用 self.parameters() 進行重置。否則,我們將原地重置 tensordict 中的引數。這對於引數不儲存在模組本身的函式式模組非常有用。

返回值:

新的引數的 tensordict,僅在 parameters 不為 None 時返回。

示例

>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> old_param = net[0].weight.clone()
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> module.reset_parameters()
>>> (old_param == net[0].weight).any()
tensor(False)

此方法也支援函式式引數取樣

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> params = TensorDict.from_module(module)
>>> old_params = params.clone(recurse=True)
>>> module.reset_parameters(params)
>>> (old_params == params).any()
False
select_out_keys(*out_keys) TensorDictModuleBase

選擇將在輸出 tensordict 中找到的鍵。

這在需要刪除複雜圖中的中間鍵,或者這些鍵的存在可能引發意外行為時非常有用。

原始的 out_keys 仍然可以透過 module.out_keys_source 訪問。

引數:

*out_keys (字串序列或字串元組) – 應在輸出 tensordict 中找到的 out_keys。

返回值:已就地修改的同一模組,其中 out_keys 已更新。

最簡單的用法是結合 TensorDictModule

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

此功能也適用於已分派的引數: .. 標題:示例

>>> mod(torch.zeros(()), torch.ones(()))
tensor(2.)

此更改將原地發生(即返回的仍是同一模組,但 out_keys 列表已更新)。可以使用 TensorDictModuleBase.reset_out_keys() 方法恢復此更改。

示例

>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

這也適用於其他類,例如 Sequential: .. 標題:示例

>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

© 版權所有 2022, Meta。

使用 Sphinx 構建,主題由 Read the Docs 提供。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源