快捷方式

tensordict.nn.dispatch

class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)

允許使用 kwargs 呼叫預期接收 TensorDict 的函式。

dispatch() 必須在具有 in_keys(或由 source 關鍵字引數指示的另一個鍵源)和 out_keys(或另一個 dest 鍵列表)屬性的模組中使用,這些屬性指示要從 tensordict 讀取和寫入哪些鍵。被包裝的函式也應該有一個 tensordict 開頭引數。

如果 out_keys 中只有一個元素,則結果函式將返回一個張量;否則,它將返回一個按模組的 out_keys 排序的元組。

當需要傳遞額外引數時,dispatch() 可以用作方法或類。

引數:
  • separator (str, 可選) – 用於組合 in_keys 中的子鍵的分隔符,這些子鍵是字串元組。預設為 "_"

  • source (str鍵列表, 可選) – 如果提供字串,則指向包含要使用的輸入鍵列表的模組屬性。如果改為提供列表,則將包含用作模組輸入的鍵。預設為 "in_keys",這是 TensorDictModule 的輸入鍵列表的屬性名稱。

  • dest (str鍵列表, 可選) – 如果提供字串,則指向包含要使用的輸出鍵列表的模組屬性。如果改為提供列表,則將包含用作模組輸出的鍵。預設為 "out_keys",這是 TensorDictModule 的輸出鍵列表的屬性名稱。

  • auto_batch_size (bool, 可選) – 如果為 True,則輸入的 tensordict 的批大小會自動確定為所有輸入張量中最大數量的共同維度。預設為 True

示例

>>> class MyModule(nn.Module):
...     in_keys = ["a"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # equivalently
>>> class MyModule(nn.Module):
...     keys_in = ["a"]
...     keys_out = ["b"]
...
...     @dispatch(source="keys_in", dest="keys_out")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # or this
>>> class MyModule(nn.Module):
...     @dispatch(source=["a"], dest=["b"])
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()

dispatch_kwargs() 也適用於使用預設分隔符 "_" 的巢狀鍵。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(a_c=torch.zeros(1, 2))
>>> assert (b == 1).all()

如果需要使用其他分隔符,可以在建構函式中使用 separator 引數指定。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch(separator="sep")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(asepc=torch.zeros(1, 2))
>>> assert (b == 1).all()

由於輸入鍵是按順序排列的字串序列,dispatch() 也可以與未命名引數一起使用,此時引數順序必須與輸入鍵的順序匹配。

注意

如果第一個引數是 TensorDictBase 例項,則假定沒有使用 dispatch,並且該 tensordict 包含透過模組執行所需的所有必要資訊。換句話說,不能使用模組輸入的第一個鍵指向 tensordict 例項來分解 tensordict。通常,建議只對 tensordict 的葉子節點使用 dispatch()

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c"), "d"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + tensordict["d"]
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> b, = module(torch.zeros(1, 2), torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> try:
...     b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2))  # fails
... except:
...     print("oopsy!")
...

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源