DoubleToFloat
- class torchrl.envs.transforms.DoubleToFloat(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)
Maps one dtype to another for the selected keys.
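The class behaves differently depending on whether in_keys or in_keys_inv are provided at construction time:
- If keys are provided, those entries and only those entries are converted from float64 to float32.
- If no keys are provided and the object is part of an environment's transform registry, the input and output specs whose dtype is float64 are used as in_keys_inv and in_keys, respectively.
- If no keys are provided and the object is used without an environment, the forward/inverse pass scans the input tensordict for all float64 values and maps them to float32 tensors. For large data structures this can hurt performance, because the scan is not free. The keys to be converted are not cached. Note that in this case out_keys (or out_keys_inv) cannot be passed, because the order in which the keys are processed cannot be predicted precisely.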
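The scan mode described in the last bullet can be made concrete with a library-free sketch of its traversal. It rests on loudly stated assumptions. Nested Python dicts stand in for a TensorDict. Stdlib array.array buffers stand in for tensors, with typecode "d" (C double) playing float64 and "f" (C float) playing float32. The function name scan_double_to_float is hypothetical and is not part of torchrl.

```python
# Hypothetical, library-free sketch of the "scan" mode; NOT torchrl code.
# Assumptions: nested dicts stand in for a TensorDict, and array.array
# buffers stand in for tensors ("d" = C double ~ float64, "f" = C float ~ float32).
from array import array
from typing import Any, Dict, List, Tuple

NestedKey = Tuple[str, ...]


def scan_double_to_float(td: Dict[str, Any], prefix: NestedKey = ()) -> List[NestedKey]:
    """Cast every "d" array in a nested dict to "f" in place; return the converted keys."""
    converted: List[NestedKey] = []
    # Iterate over a snapshot of the items so entries can be replaced safely.
    for key, value in list(td.items()):
        path = prefix + (key,)
        if isinstance(value, dict):
            # Recurse into sub-dicts, mirroring nested TensorDict entries.
            converted.extend(scan_double_to_float(value, path))
        elif isinstance(value, array) and value.typecode == "d":
            # float64 -> float32: build a new single-precision buffer.
            td[key] = array("f", value.tolist())
            converted.append(path)
        # Any other entry (e.g. booleans/integers) is left untouched.
    # Nothing is cached: every call walks the whole structure again.
    return converted
```

Because the sketch caches nothing, each call re-walks the full structure. That is the per-call overhead the paragraph above warns about for large inputs.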
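- Parameters:
in_keys (sequence of NestedKey, optional) – list of double keys to be converted to float before being exposed to external objects and functions.
out_keys (sequence of NestedKey, optional) – list of destination keys. Defaults to in_keys if not provided.
in_keys_inv (sequence of NestedKey, optional) – list of float keys to be converted to double before being passed to the contained base_env or storage.
out_keys_inv (sequence of NestedKey, optional) – list of destination keys for the inverse transform. Defaults to in_keys_inv if not provided.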
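The parameter list implies two things: the defaulting rules for out_keys and out_keys_inv, and a split between the forward direction (double to float) and the inverse direction (float to double). Both can be sketched as follows. This is a hypothetical stand-in, not torchrl's implementation. The class name DoubleToFloatSketch, its forward/inv method names and its length check are illustrative assumptions. Flat dicts replace TensorDicts, and array.array typecodes "d"/"f" replace float64/float32 tensors.

```python
# Hypothetical sketch of the explicit-keys mode; NOT the torchrl class.
# Assumptions: flat dicts stand in for a TensorDict, and array("d") /
# array("f") stand in for float64 / float32 tensors.
from array import array
from typing import Dict, List, Optional, Sequence


class DoubleToFloatSketch:
    def __init__(
        self,
        in_keys: Sequence[str],
        out_keys: Optional[Sequence[str]] = None,
        in_keys_inv: Sequence[str] = (),
        out_keys_inv: Optional[Sequence[str]] = None,
    ) -> None:
        self.in_keys: List[str] = list(in_keys)
        # out_keys defaults to in_keys when not provided.
        self.out_keys: List[str] = list(out_keys) if out_keys is not None else list(self.in_keys)
        self.in_keys_inv: List[str] = list(in_keys_inv)
        # out_keys_inv defaults to in_keys_inv when not provided.
        self.out_keys_inv: List[str] = (
            list(out_keys_inv) if out_keys_inv is not None else list(self.in_keys_inv)
        )
        # Illustrative check: keys are paired positionally, so lengths must match.
        if len(self.out_keys) != len(self.in_keys) or len(self.out_keys_inv) != len(self.in_keys_inv):
            raise ValueError("each input key needs exactly one output key")

    def forward(self, td: Dict[str, array]) -> Dict[str, array]:
        # double -> float, only for the listed keys; other entries are untouched.
        for src, dst in zip(self.in_keys, self.out_keys):
            td[dst] = array("f", td[src].tolist())
        return td

    def inv(self, td: Dict[str, array]) -> Dict[str, array]:
        # float -> double, e.g. before an action reaches a float64 base env.
        for src, dst in zip(self.in_keys_inv, self.out_keys_inv):
            td[dst] = array("d", td[src].tolist())
        return td
```

Because input and output keys are paired positionally, out_keys only makes sense when the input keys are known in advance. This is consistent with the note above that out_keys cannot be passed in scan mode.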
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import DoubleToFloat
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...      'not_transformed': torch.ones(1, dtype=torch.double),
...     }, [])
>>> transform = DoubleToFloat(in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64
In "automatic" mode, all float64 entries are converted:
Examples
>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...      'not_transformed': torch.ones(1, dtype=torch.double),
...     }, [])
>>> transform = DoubleToFloat()
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32
The same rule applies when an environment is constructed without specifying the transform keys:
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Composite, Unbounded
>>> from torchrl.envs import EnvBase, TransformedEnv
>>> from torchrl.envs.transforms import DoubleToFloat
>>> class MyEnv(EnvBase):
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64))
...         self.action_spec = Unbounded((), dtype=torch.float64)
...         self.reward_spec = Unbounded((1,), dtype=torch.float64)
...         self.done_spec = Unbounded((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed):
...         pass
>>> env = TransformedEnv(MyEnv(), DoubleToFloat())
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]
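In the environment example, obs and reward become in_keys, action becomes in_keys_inv, and the boolean done entry is skipped. That selection amounts to filtering the output and input specs by dtype. The sketch below assumes specs are reduced to plain {key: dtype name} dicts. The helper keys_from_specs is hypothetical and does not use torchrl's spec API.

```python
# Hypothetical sketch of spec-based key selection; NOT torchrl's spec API.
# Assumption: each spec is reduced to a plain {key: dtype-name} dict.
from typing import Dict, List, Tuple


def keys_from_specs(
    output_specs: Dict[str, str], input_specs: Dict[str, str]
) -> Tuple[List[str], List[str]]:
    """Return (in_keys, in_keys_inv): every float64 entry of the output / input specs."""
    # Outputs the env produces as float64 are cast to float32 on the way out.
    in_keys = [key for key, dtype in output_specs.items() if dtype == "float64"]
    # Inputs the env expects as float64 are cast back from float32 on the way in.
    in_keys_inv = [key for key, dtype in input_specs.items() if dtype == "float64"]
    return in_keys, in_keys_inv
```

Fed the specs of MyEnv above, reduced to {"obs": "float64", "reward": "float64", "done": "bool"} for outputs and {"action": "float64"} for inputs, this filter reproduces the key lists asserted at the end of the example.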