VecNorm¶
- class torchrl.envs.transforms.VecNorm(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, shared_td: Optional[TensorDictBase] = None, lock: mp.Lock = None, decay: float = 0.9999, eps: float = 0.0001, shapes: List[torch.Size] = None)[source]¶
用於 torchrl 環境的移動平均歸一化層。
VecNorm 跟蹤資料集的彙總統計資訊,以便進行即時標準化。如果變換處於“eval”模式,則不會更新執行中的統計資訊。
如果多個程序正在執行類似的環境,可以傳遞一個放置在共享記憶體中的 TensorDictBase 例項:在這種情況下,每次查詢歸一化層時,都會更新共享同一引用的所有程序的值。
為了在推理時使用 VecNorm 並避免使用新觀測值更新值,應將此層替換為
to_observation_norm()。這將提供一個靜態版本的 VecNorm,當源變換更新時,它不會被更新。要獲取 VecNorm 層的凍結副本,請參閱frozen_copy()。- 引數:
in_keys (NestedKey 的序列, 可選) – 要更新的鍵。預設值: [“observation”, “reward”]
out_keys (NestedKey 的序列, 可選) – 目標鍵。預設值為
in_keys。shared_td (TensorDictBase, 可選) – 包含變換鍵的共享 tensordict。
lock (mp.Lock) – 用於防止程序之間出現競爭條件的鎖。預設值為 None(在初始化期間建立鎖)。
decay (數字, 可選) – 移動平均的衰減率。預設值: 0.99
eps (數字, 可選) – 執行標準差的下界(用於防止數值下溢)。預設值為 1e-4。
shapes (List[torch.Size], 可選) – 如果提供,表示每個 in_keys 的形狀。其長度必須與
in_keys的長度匹配。每個形狀必須與對應條目的尾部維度匹配。否則,條目的特徵維度(即不屬於 tensordict batch-size 的所有維度)將被視為特徵維度。
示例
>>> from torchrl.envs.libs.gym import GymEnv >>> t = VecNorm(decay=0.9) >>> env = GymEnv("Pendulum-v0") >>> env = TransformedEnv(env, t) >>> tds = [] >>> for _ in range(1000): ... td = env.rand_step() ... if td.get("done"): ... _ = env.reset() ... tds += [td] >>> tds = torch.stack(tds, 0) >>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all()) tensor(True) >>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all()) tensor(True)
建立用於跨程序歸一化的共享 tensordict。
- 引數:
env (EnvBase) – 用於建立 tensordict 的示例環境
keys (NestedKey 的序列, 可選) – 需要歸一化的鍵。預設值為 [“next”, “reward”]
memmap (bool) – 如果為
True,生成的 tensordict 將被轉換為記憶體對映(使用 memmap_())。否則,tensordict 將被放置在共享記憶體中。
- 返回值:
一個共享記憶體區域,用於傳送給每個程序。
示例
>>> from torch import multiprocessing as mp >>> queue = mp.Queue() >>> env = make_env() >>> td_shared = VecNorm.build_td_for_shared_vecnorm(env, ... ["next", "reward"]) >>> assert td_shared.is_shared() >>> queue.put(td_shared) >>> # on workers >>> v = VecNorm(shared_td=queue.get()) >>> env = TransformedEnv(make_env(), v)
- forward(tensordict: TensorDictBase) TensorDictBase¶
讀取輸入的 tensordict,並對選定的鍵應用變換。
- freeze() VecNorm[source]¶
凍結 VecNorm,呼叫時避免更新統計資訊。
參見
unfreeze()。
- get_extra_state() OrderedDict[source]¶
返回要包含在模組 state_dict 中的任何額外狀態。
如果您需要儲存額外狀態,請為您的模組實現此函式及相應的
set_extra_state()。構建模組的 state_dict() 時會呼叫此函式。請注意,額外狀態應可被 pickle 序列化,以確保 state_dict 的序列化正常工作。我們僅為 Tensor 的序列化提供向後相容性保證;如果其他物件的序列化 pickle 形式發生變化,可能會破壞向後相容性。
- 返回值:
要儲存在模組 state_dict 中的任何額外狀態
- 返回型別:
object
- property loc¶
返回一個包含用於仿射變換的 loc 的 TensorDict。
- property scale¶
返回一個包含用於仿射變換的 scale 的 TensorDict。
- set_extra_state(state: OrderedDict) None[source]¶
設定載入的 state_dict 中包含的額外狀態。
此函式由
load_state_dict()呼叫,用於處理在 state_dict 中找到的任何額外狀態。如果您需要在模組的 state_dict 中儲存額外狀態,請實現此函式及相應的get_extra_state()。- 引數:
state (dict) – 來自 state_dict 的額外狀態
- property standard_normal¶
給定 loc 和 scale 的仿射變換是否遵循標準正態方程。
類似於
ObservationNorm的 standard_normal 屬性。始終返回
True。
- to_observation_norm() Union[Compose, ObservationNorm][source]¶
將 VecNorm 轉換為可在推理時使用的 ObservationNorm 類。
ObservationNorm層可以使用state_dict()API 進行更新。示例
>>> from torchrl.envs import GymEnv, VecNorm >>> vecnorm = VecNorm(in_keys=["observation"]) >>> train_env = GymEnv("CartPole-v1", device=None).append_transform( ... vecnorm) >>> >>> r = train_env.rollout(4) >>> >>> eval_env = GymEnv("CartPole-v1").append_transform( ... vecnorm.to_observation_norm()) >>> print(eval_env.transform.loc, eval_env.transform.scale) >>> >>> r = train_env.rollout(4) >>> # Update entries with state_dict >>> eval_env.transform.load_state_dict( ... vecnorm.to_observation_norm().state_dict()) >>> print(eval_env.transform.loc, eval_env.transform.scale)
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]¶
轉換觀測規範,使結果規範與變換對映匹配。
- 引數:
observation_spec (TensorSpec) – 變換前的規範
- 返回值:
變換後預期的規範