AutoResetTransform¶
- class torchrl.envs.transforms.AutoResetTransform(*, replace: bool | None = None, fill_float='nan', fill_int=- 1, fill_bool=False)[原始碼]¶
一個用於自動重置環境的變換 (transform)。
這個變換 (transform) 可以附加到任何自動重置的環境上,或者使用
env = SomeEnvClass(..., auto_reset=True)自動附加。如果該變換被顯式地附加到一個環境上,則必須使用一個AutoResetEnv。一個自動重置的環境必須具備以下屬性 (與此描述不同的地方應透過繼承此類來處理):
reset 函式可以在開始時 (例項化之後) 呼叫一次,無論是否有效果。之後是否允許呼叫 reset 取決於環境本身。
在一次 Rollout 過程中,任何
done狀態都將導致重置,併產生一個觀察結果,該結果不是當前 episode 的最後一個觀察結果,而是下一個 episode 的第一個觀察結果 (此變換將提取並快取此觀察結果,並用某個任意值填充 obs)。
- 關鍵詞引數:
replace (bool, 可選) – 如果為
False,即使值無效,也會直接放在"next"條目中。預設為True。值為False會覆蓋後續任何填充關鍵詞引數。該引數也可以透過建構函式方法傳遞,透過傳遞auto_reset_replace引數實現:env = FooEnv(..., auto_reset=True, auto_reset_replace=False)。fill_float (
float或 str, 可選) – 用於填充結束 episode 的浮點張量的值。值為None意味著不替換 (即使值無效,也會直接放在"next"條目中)。fill_int (int, 可選) – 用於填充結束 episode 的有符號整數張量的值。值為
None意味著不替換 (即使值無效,也會直接放在"next"條目中)。fill_bool (bool, 可選) – 用於填充結束 episode 的布林張量的值。值為
None意味著不替換 (即使值無效,也會直接放在"next"條目中)。
這些引數僅在顯式例項化該變換時可用 (而不是透過 EnvType(…, auto_reset=True))。
示例
>>> from torchrl.envs import GymEnv >>> from torchrl.envs import set_gym_backend >>> import torch >>> torch.manual_seed(0) >>> >>> class AutoResettingGymEnv(GymEnv): ... def _step(self, tensordict): ... tensordict = super()._step(tensordict) ... if tensordict["done"].any(): ... td_reset = super().reset() ... tensordict.update(td_reset.exclude(*self.done_keys)) ... return tensordict ... ... def _reset(self, tensordict=None): ... if tensordict is not None and "_reset" in tensordict: ... return tensordict.copy() ... return super()._reset(tensordict) >>> >>> with set_gym_backend("gym"): ... env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True) ... env.set_seed(0) ... r = env.rollout(30, break_when_any_done=False) >>> print(r["next", "done"].squeeze()) tensor([False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False]) >>> print("observation after reset are set as nan", r["next", "observation"]) observation after reset are set as nan tensor([[-4.3633e-02, -1.4877e-01, 1.2849e-02, 2.7584e-01], [-4.6609e-02, 4.6166e-02, 1.8366e-02, -1.2761e-02], [-4.5685e-02, 2.4102e-01, 1.8111e-02, -2.9959e-01], [-4.0865e-02, 4.5644e-02, 1.2119e-02, -1.2542e-03], [-3.9952e-02, 2.4059e-01, 1.2094e-02, -2.9009e-01], [-3.5140e-02, 4.3554e-01, 6.2920e-03, -5.7893e-01], [-2.6429e-02, 6.3057e-01, -5.2867e-03, -8.6963e-01], [-1.3818e-02, 8.2576e-01, -2.2679e-02, -1.1640e+00], [ 2.6972e-03, 1.0212e+00, -4.5959e-02, -1.4637e+00], [ 2.3121e-02, 1.2168e+00, -7.5232e-02, -1.7704e+00], [ 4.7457e-02, 1.4127e+00, -1.1064e-01, -2.0854e+00], [ 7.5712e-02, 1.2189e+00, -1.5235e-01, -1.8289e+00], [ 1.0009e-01, 1.0257e+00, -1.8893e-01, -1.5872e+00], [ nan, nan, nan, nan], [-3.9405e-02, -1.7766e-01, -1.0403e-02, 3.0626e-01], [-4.2959e-02, -3.7263e-01, -4.2775e-03, 5.9564e-01], [-5.0411e-02, -5.6769e-01, 7.6354e-03, 8.8698e-01], [-6.1765e-02, -7.6292e-01, 2.5375e-02, 1.1820e+00], [-7.7023e-02, -9.5836e-01, 4.9016e-02, 1.4826e+00], [-9.6191e-02, -7.6387e-01, 7.8667e-02, 1.2056e+00], [-1.1147e-01, -9.5991e-01, 1.0278e-01, 1.5219e+00], [-1.3067e-01, -7.6617e-01, 1.3322e-01, 1.2629e+00], [-1.4599e-01, -5.7298e-01, 1.5848e-01, 1.0148e+00], [-1.5745e-01, -7.6982e-01, 1.7877e-01, 1.3527e+00], [-1.7285e-01, -9.6668e-01, 2.0583e-01, 1.6956e+00], [ nan, nan, nan, nan], [-4.3962e-02, 1.9845e-01, -4.5015e-02, -2.5903e-01], [-3.9993e-02, 3.9418e-01, -5.0196e-02, -5.6557e-01], [-3.2109e-02, 5.8997e-01, -6.1507e-02, -8.7363e-01], [-2.0310e-02, 3.9574e-01, -7.8980e-02, -6.0090e-01]])