快捷方式

VmapModule

class torchrl.modules.VmapModule(*args, **kwargs)[source]

一個 TensorDictModule 包裝器,用於對輸入進行 vmap 操作。

它旨在與接受比提供的資料少一個批處理維度的模組一起使用。透過使用此包裝器,可以隱藏一個批處理維度並滿足被包裝模組的要求。

引數:
  • module (TensorDictModuleBase) – 要進行 vmap 操作的模組。

  • vmap_dim (int, optional) – vmap 輸入和輸出維度。如果未提供,則假定為 tensordict 的最後一個維度。

注意

由於 vmap 要求控制輸入的批大小,此模組不支援分派的引數

示例

>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
>>> sample_in = torch.ones((10,3,2))
>>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10])
>>> lam(sample_in)
>>> vm = VmapModule(lam, 0)
>>> vm(sample_in_td)
>>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
forward(tensordict)[source]

定義每次呼叫時執行的計算。

應由所有子類覆蓋。

注意

儘管需要在該函式內部定義前向傳播的實現方式,但之後應呼叫 Module 例項而不是該函式本身,因為前者負責執行已註冊的鉤子,而後者會默默忽略它們。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題的解答

檢視資源