快捷方式

SafeProbabilisticTensorDictSequential

class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[原始碼]

tensordict.nn.ProbabilisticTensorDictSequential 的子類,接受 `TensorSpec` 引數來控制輸出域。

TensorDictSequential 類似,但強制要求序列中的最後一個模組是 ProbabilisticTensorDictModule,並暴露 `get_dist` 方法以便從 ProbabilisticTensorDictModule 中恢復分佈物件。

引數::
  • modules (TensorDictModules 可迭代物件) – TensorDictModule 例項的有序序列,以 ProbabilisticTensorDictModule 結尾,按順序執行。

  • partial_tolerant (bool, 可選) – 如果為 `True`,輸入 tensordict 可以缺少一些輸入鍵。在這種情況下,只有那些在現有鍵下可以執行的模組會被執行。此外,如果輸入 tensordict 是 tensordict 的惰性堆疊 *並且* partial_tolerant 為 `True` *並且* 堆疊沒有所需的鍵,那麼 TensorDictSequential 將掃描子 tensordict,查詢包含所需鍵的那些(如果存在)。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源