TrajCounter¶
- class torchrl.envs.transforms.TrajCounter(out_key: NestedKey = 'traj_count')[source]¶
全域性軌跡計數器變換。
TrajCounter 可用於計算任何 TorchRL 環境中的軌跡數量(即呼叫 reset 的次數)。此變換可在單個節點內的多個程序中工作(見下注)。單個變換隻能計算與單個完成狀態相關的軌跡,但只要巢狀完成狀態的字首與計數器鍵的字首匹配,就可以接受巢狀完成狀態。
- 引數:
out_key (NestedKey, optional) – 軌跡計數器的條目名稱。預設為
"traj_count"。
示例
>>> from torchrl.envs import GymEnv, StepCounter, TrajCounter >>> env = GymEnv("Pendulum-v1").append_transform(StepCounter(6)) >>> env = env.append_transform(TrajCounter()) >>> r = env.rollout(18, break_when_any_done=False) # 18 // 6 = 3 trajectories >>> r["next", "traj_count"] tensor([[0], [0], [0], [0], [0], [0], [1], [1], [1], [1], [1], [1], [2], [2], [2], [2], [2], [2]])
注意
在 workers 之間共享軌跡計數器可以透過多種方式實現,但這通常涉及將環境封裝在
EnvCreator中。不這樣做可能會在變換序列化期間導致錯誤。計數器將在 workers 之間共享,這意味著在任何時間點,都可以保證不會有兩個環境共享相同的軌跡計數(並且每個 (步數, 軌跡數) 對都將是唯一的)。以下是跨程序共享TrajCounter物件的有效方法示例>>> # Option 1: Create the trajectory counter outside the environment. >>> # This requires the counter to be cloned within the transformed env, as a single transform object cannot have two parents. >>> t = TrajCounter() >>> def make_env(max_steps=4, t=t): ... # See CountingEnv in torchrl.test.mocking_classes ... env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) ... env.transform.transform_observation_spec(env.base_env.observation_spec) ... return env >>> penv = ParallelEnv( ... 2, ... [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)], ... mp_start_method="spawn", ... ) >>> # Option 2: Create the transform within the constructor. >>> # In this scenario, we still need to tell each sub-env what kwarg has to be used. >>> # Both EnvCreator and ParallelEnv offer that possibility. >>> def make_env(max_steps=4): ... t = TrajCounter() ... env = TransformedEnv(CountingEnv(max_steps=max_steps), t) ... env.transform.transform_observation_spec(env.base_env.observation_spec) ... return env >>> make_env_c0 = EnvCreator(make_env) >>> # Create a variant of the env with different kwargs >>> make_env_c1 = make_env_c0.make_variant(max_steps=5) >>> penv = ParallelEnv( ... 2, ... [make_env_c0, make_env_c1], ... mp_start_method="spawn", ... ) >>> # Alternatively, pass the kwargs to the ParallelEnv >>> penv = ParallelEnv( ... 2, ... [make_env_c0, make_env_c0], ... create_env_kwargs=[{"max_steps": 5}, {"max_steps": 4}], ... mp_start_method="spawn", ... )
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)[source]¶
將
state_dict中的引數和緩衝區複製到此模組及其後代中。如果
strict為True,則state_dict的鍵必須與此模組的state_dict()函式返回的鍵完全匹配。警告
如果
assign為True,則除非get_swap_module_params_on_conversion()為True,否則必須在呼叫load_state_dict後建立最佳化器。- 引數:
state_dict (dict) – 包含引數和永續性緩衝區的字典。
strict (bool, optional) – 是否嚴格執行
state_dict中的鍵與此模組的state_dict()函式返回的鍵匹配。預設為Trueassign (bool, optional) – 當設定為
False時,保留當前模組中張量的屬性,而設定為True時保留 state dict 中張量的屬性。唯一的例外是requires_grad欄位。預設為False
- 返回:
- missing_keys 是一個字串列表,包含預期但
提供的
state_dict中缺失的鍵。
- unexpected_keys 是一個字串列表,包含此模組
未預期但在提供的
state_dict中存在的鍵。
- 返回型別:
包含
missing_keys和unexpected_keys欄位的NamedTuple
注意
如果引數或緩衝區註冊為
None且其對應的鍵存在於state_dict中,load_state_dict()將引發RuntimeError。
- state_dict(*args, destination=None, prefix='', keep_vars=False)[source]¶
返回一個包含模組完整狀態引用的字典。
包括引數和永續性緩衝區(例如,執行平均值)。鍵是對應的引數和緩衝區名稱。設定為
None的引數和緩衝區不包括在內。注意
返回的物件是淺複製。它包含對模組引數和緩衝區的引用。
警告
目前
state_dict()也按順序接受 destination、prefix 和 keep_vars 的位置引數。但是,此用法已被棄用,未來版本將強制使用關鍵字引數。警告
請避免使用引數
destination,因為它不面向終端使用者設計。- 引數:
destination (dict, optional) – 如果提供,模組狀態將更新到該字典中並返回同一物件。否則,將建立並返回一個
OrderedDict。預設為None。prefix (str, optional) – 新增到引數和緩衝區名稱前的字首,用於在 state_dict 中構成鍵。預設為
''。keep_vars (bool, optional) – 預設情況下,state dict 中返回的
Tensor與 autograd 分離。如果設定為True,則不會進行分離。預設為False。
- 返回:
包含模組完整狀態的字典
- 返回型別:
dict
示例
>>> # xdoctest: +SKIP("undefined vars") >>> module.state_dict().keys() ['bias', 'weight']
- transform_observation_spec(observation_spec: Composite) Composite[source]¶
變換觀測規範,使結果規範與變換對映匹配。
- 引數:
observation_spec (TensorSpec) – 變換前的規範
- 返回:
變換後預期的規範