step_mdp¶
- torchrl.envs.utils.step_mdp(tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, reward_keys: NestedKey | list[NestedKey] = 'reward', done_keys: NestedKey | list[NestedKey] = 'done', action_keys: NestedKey | list[NestedKey] = 'action') → TensorDictBase [source]¶
Creates a new TensorDict reflecting a step in time of the input TensorDict.
Given a TensorDict retrieved after a step, returns the "next" indexed TensorDict. The arguments allow for precise control over what should be kept and what should be copied from the "next" entry. The default behaviour is: move the observation entries, the reward and the done state to the root, exclude the current action, and keep all extra keys (non-action, non-done, non-reward).
- Parameters:
tensordict (TensorDictBase) – The TensorDict with keys to be renamed.
next_tensordict (TensorDictBase, optional) – The destination TensorDict. If None, a new TensorDict is created.
keep_other (bool, optional) – If True, all keys that do not start with 'next_' will be kept. Defaults to True.
exclude_reward (bool, optional) – If True, the "reward" key will be discarded from the resulting TensorDict. If False, it will be copied (and replaced) from the "next" entry (if present). Defaults to True.
exclude_done (bool, optional) – If True, the "done" key will be discarded from the resulting TensorDict. If False, it will be copied (and replaced) from the "next" entry (if present). Defaults to False.
exclude_action (bool, optional) – If True, the "action" key will be discarded from the resulting TensorDict. If False, it will be kept in the root TensorDict (since it should not be present in the "next" entry). Defaults to True.
reward_keys (NestedKey or list of NestedKey, optional) – The keys where the reward is written. Defaults to "reward".
done_keys (NestedKey or list of NestedKey, optional) – The keys where the done state is written. Defaults to "done".
action_keys (NestedKey or list of NestedKey, optional) – The keys where the action is written. Defaults to "action".
- Returns:
A new TensorDict (or next_tensordict if provided) containing the tensors of the t+1 step.
- Return type:
TensorDictBase
See also
EnvBase.step_mdp() is the class-based version of this free function. It will attempt to cache the key values to reduce the overhead of making a step in the MDP.
Examples
>>> from tensordict import TensorDict
>>> from torchrl.envs.utils import step_mdp
>>> import torch
>>> td = TensorDict({
...     "done": torch.zeros((), dtype=torch.bool),
...     "reward": torch.zeros(()),
...     "extra": torch.zeros(()),
...     "next": TensorDict({
...         "done": torch.zeros((), dtype=torch.bool),
...         "reward": torch.zeros(()),
...         "obs": torch.zeros(()),
...     }, []),
...     "obs": torch.zeros(()),
...     "action": torch.zeros(()),
... }, [])
>>> print(step_mdp(td))
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_done=True))  # "done" is dropped
TensorDict(
    fields={
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_reward=False))  # "reward" is kept
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_action=False))  # "action" persists at the root
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, keep_other=False))  # "extra" is missing
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
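The reward_keys, done_keys and action_keys arguments let the same logic be applied to non-default key names. The snippet below is a minimal sketch under the assumption that renamed flat keys behave exactly like the default "reward"/"done"/"action" keys above; the names "r", "is_done" and "act" are purely illustrative.
>>> from tensordict import TensorDict
>>> from torchrl.envs.utils import step_mdp
>>> import torch
>>> td = TensorDict({
...     "is_done": torch.zeros((), dtype=torch.bool),  # hypothetical done key
...     "r": torch.zeros(()),                          # hypothetical reward key
...     "act": torch.zeros(()),                        # hypothetical action key
...     "obs": torch.zeros(()),
...     "next": TensorDict({
...         "is_done": torch.zeros((), dtype=torch.bool),
...         "r": torch.zeros(()),
...         "obs": torch.zeros(()),
...     }, []),
... }, [])
>>> out = step_mdp(td, reward_keys="r", done_keys="is_done", action_keys="act")
>>> print(sorted(out.keys()))  # expected: ['is_done', 'obs'] under the default flags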
Warning
This function will not work properly if the reward key is also part of the input keys when it is excluded. This is why the RewardSum transform registers the episode reward in the observation rather than the reward spec by default. When using the fast, cached version of this function (_StepMDP), this issue should not arise.
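For completeness, here is a minimal rollout sketch using the class-based, cached variant mentioned in the see-also note above. It assumes a Gym backend with "Pendulum-v1" is installed and that EnvBase.step_mdp() is available in your TorchRL version; otherwise the free function step_mdp(td) can be substituted on the last line.
>>> from torchrl.envs import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> td = env.reset()
>>> for _ in range(3):
...     td = env.rand_action(td)  # sample a random action into td
...     td = env.step(td)         # writes the "next" entry
...     td = env.step_mdp(td)     # cached equivalent of step_mdp(td)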