SafeProbabilisticTensorDictSequential¶
- class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[原始碼]¶
是
tensordict.nn.ProbabilisticTensorDictSequential的子類,接受 `TensorSpec` 引數來控制輸出域。與
TensorDictSequential類似,但強制要求序列中的最後一個模組是ProbabilisticTensorDictModule,並暴露 `get_dist` 方法以便從ProbabilisticTensorDictModule中恢復分佈物件。- 引數::
modules (TensorDictModules 可迭代物件) – TensorDictModule 例項的有序序列,以 ProbabilisticTensorDictModule 結尾,按順序執行。
partial_tolerant (bool, 可選) – 如果為 `True`,輸入 tensordict 可以缺少一些輸入鍵。在這種情況下,只有那些在現有鍵下可以執行的模組會被執行。此外,如果輸入 tensordict 是 tensordict 的惰性堆疊 *並且* partial_tolerant 為 `True` *並且* 堆疊沒有所需的鍵,那麼 TensorDictSequential 將掃描子 tensordict,查詢包含所需鍵的那些(如果存在)。