快捷方式

EnvCreator

torchrl.envs.EnvCreator(create_env_fn: Callable[[...], EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True, **kwargs)[原始碼]

環境建立器類。

EnvCreator 是一個通用的環境建立器類,可在多程序環境下建立環境時替代 lambda 函式。如果在子程序中建立的環境需要與主程序共享資訊(例如用於 VecNorm 變換),EnvCreator 會將共享記憶體中 tensordict 的指標傳遞給每個程序,以便它們保持同步。

引數:
  • create_env_fn (callable) – 一個可呼叫物件,返回一個 EnvBase 例項。

  • create_env_kwargs (dict, optional) – 環境建立器的 kwargs。

  • share_memory (bool, optional) – 如果為 False,環境生成的 tensordict 不會放在共享記憶體中。

  • **kwargs – 在構建環境時要傳遞的額外關鍵字引數。

示例

>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations match on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p1)
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])
make_variant(**kwargs) EnvCreator[原始碼]

建立 EnvCreator 的一個變體,指向相同的底層元資料,但在構建時使用不同的關鍵字引數。

這對於共享狀態的變換 (transforms) 可能有用,例如 TrajCounter

示例

>>> from torchrl.envs import GymEnv
>>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
>>> env_creator_cartpole = env_creator_pendulum(env_name="CartPole-v1")

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源