快捷方式

SelectKeys

class torchrl.trainers.SelectKeys(keys: Sequence[str])[source]

選擇 TensorDict 批次中的鍵。

引數:

keys (字串可迭代物件) – 要在 tensordict 中選擇的鍵。

示例

>>> trainer = make_trainer()
>>> key1 = "first key"
>>> key2 = "second key"
>>> td = TensorDict(
...     {
...         key1: torch.randn(3),
...         key2: torch.randn(3),
...     },
...     [],
... )
>>> trainer.register_op("batch_process", SelectKeys([key1]))
>>> td_out = trainer._process_batch_hook(td)
>>> assert key1 in td_out.keys()
>>> assert key2 not in td_out.keys()
register(trainer, name='select_keys') None[source]

在訓練器中的預設位置註冊鉤子。

引數:
  • trainer (Trainer) – 必須註冊鉤子的訓練器。

  • name (str) – 鉤子的名稱。

注意

要在預設位置以外的其他位置註冊鉤子,請使用 register_op()

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源