ValueOperator¶
- class torchrl.modules.tensordict_module.ValueOperator(*args, **kwargs)[source]¶
強化學習中值函式的通用類。
ValueOperator 類為 in_keys 和 out_keys 引數提供了預設值(分別為 [“observation”] 以及 [“state_value”] 或 [“state_action_value”],具體取決於 “action” 鍵是否包含在 in_keys 列表中)。
- 引數:
module (nn.Module) – 一個
torch.nn.Module,用於將輸入對映到輸出引數空間。in_keys (str 的可迭代物件, 可選) – 從輸入 tensordict 讀取並傳遞給模組的鍵。如果包含多個元素,則值將按照 in_keys 可迭代物件給定的順序傳遞。預設為
["observation"]。out_keys (str 的可迭代物件) – 要寫入輸入 tensordict 的鍵。out_keys 的長度必須與嵌入模組返回的張量數量匹配。使用 “_” 作為鍵可避免將張量寫入輸出。如果
"action"是in_keys的一部分,則預設為["state_value"]或["state_action_value"]。
示例
>>> import torch >>> from tensordict import TensorDict >>> from torch import nn >>> from torchrl.data import Unbounded >>> from torchrl.modules import ValueOperator >>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,]) >>> class CustomModule(nn.Module): ... def __init__(self): ... super().__init__() ... self.linear = torch.nn.Linear(6, 1) ... def forward(self, obs, action): ... return self.linear(torch.cat([obs, action], -1)) >>> module = CustomModule() >>> td_module = ValueOperator( ... in_keys=["observation", "action"], module=module ... ) >>> td = td_module(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)