CatFrames¶
- class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, padding='same', padding_value=0, as_inverse=False, reset_key: NestedKey | None = None, done_key: NestedKey | None = None)[原始碼]¶
將連續的觀察幀串聯成單個張量。
此 transform 對於在觀察到的特徵中建立運動感或速度感很有用。它也可以與需要訪問過去觀察的模型(如 transformers 等)一起使用。它最初在 “Playing Atari with Deep Reinforcement Learning” (https://arxiv.org/pdf/1312.5602.pdf) 中提出。
當在轉換後的環境中使用時,
CatFrames是一個有狀態的類,可以透過呼叫reset()方法將其重置為原生狀態。此方法接受包含"_reset"條目的 tensordict,該條目指示要重置哪個緩衝區。- 引數:
N (int) – 要串聯的觀察幀數。
dim (int) – 串聯觀察值的維度。應為負數,以確保其與不同 batch_size 的環境相容。
in_keys (sequence of NestedKey, optional) – 指向需要串聯的幀的鍵。預設為 [“pixels”]。
out_keys (sequence of NestedKey, optional) – 指向輸出寫入位置的鍵。預設為 in_keys 的值。
padding (str, optional) – 填充方法。可以是
"same"或"constant"。預設為"same",即使用第一個值進行填充。padding_value (
float, optional) – 如果padding="constant",用於填充的值。預設為 0。as_inverse (bool, optional) – 如果為
True,則應用 inverse transform。預設為False。reset_key (NestedKey, optional) – 用作部分重置指示符的 reset 鍵。必須是唯一的。如果未提供,則預設為父環境唯一的 reset 鍵(如果只有一個),否則會引發異常。
done_key (NestedKey, optional) – 用作部分 done 指示符的 done 鍵。必須是唯一的。如果未提供,則預設為
"done"。
示例
>>> from torchrl.envs.libs.gym import GymEnv >>> env = TransformedEnv(GymEnv('Pendulum-v1'), ... Compose( ... UnsqueezeTransform(-1, in_keys=["observation"]), ... CatFrames(N=4, dim=-1, in_keys=["observation"]), ... ) ... ) >>> print(env.rollout(3))
CatFramestransform 也可以離線使用,以在不同規模下重現線上幀串聯的效果(或為了限制記憶體消耗)。以下示例給出了完整的說明,以及torchrl.data.ReplayBuffer的用法示例
>>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( ... GymEnv("CartPole-v1", from_pixels=True), ... Compose( ... ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]), ... Resize(in_keys=["pixels_trsf"], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf"]), ... UnsqueezeTransform(-4, in_keys=["pixels_trsf"]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]), ... ) ... ) >>> # we design a collector >>> collector = SyncDataCollector( ... env, ... RandomPolicy(env.action_spec), ... frames_per_batch=10, ... total_frames=1000, ... ) >>> for data in collector: ... print(data) ... break >>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here. >>> # however, we need to point to both the pixel entry at the root and at the next levels: >>> t = Compose( ... ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... ) >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) >>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) >>> rb.add(data_exclude) >>> s = rb.sample(1) # the buffer has only one element >>> # let's check that our sample is the same as the batch collected during inference >>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all()
注意
CatFrames目前僅支援根級別的"done"訊號。目前不支援巢狀的done,例如在 MARL 設定中發現的那些。如果需要此功能,請在 TorchRL 倉庫上提出一個 issue。注意
在回放緩衝區中儲存幀堆疊會顯著增加記憶體消耗(增加 N 倍)。為了緩解這個問題,你可以直接在回放緩衝區中儲存軌跡,並在取樣時應用
CatFrames。此方法包括對儲存的軌跡進行切片取樣,然後應用幀堆疊 transform。為了方便起見,CatFrames提供了一個make_rb_transform_and_sampler()方法,該方法會建立適合在回放緩衝區中使用的 transform 的修改版本
一個用於緩衝區的相應
SliceSampler
- make_rb_transform_and_sampler(batch_size: int, **sampler_kwargs) Tuple[Transform, 'torchrl.data.replay_buffers.SliceSampler'][原始碼]¶
建立用於儲存幀堆疊資料時與回放緩衝區一起使用的 transform 和 sampler。
此方法透過避免在緩衝區中儲存整個幀堆疊來幫助減少儲存資料中的冗餘。它會建立一個在取樣時即時堆疊幀的 transform,以及一個確保維護正確序列長度的 sampler。
- 引數:
batch_size (int) – 用於 sampler 的 batch size。
**sampler_kwargs – 傳遞給
SliceSampler建構函式的附加關鍵字引數。
- 返回:
transform (Transform): 在取樣時即時堆疊幀的 transform。
sampler (SliceSampler): 確保維護正確序列長度的 sampler。
- 返回型別:
一個包含以下內容的元組
示例
>>> env = TransformedEnv(...) >>> catframes = CatFrames(N=4, ...) >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32) >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
注意
在使用影像時,建議在前面的
ToTensorImagetransform 中使用不同的in_keys和out_keys。這確保儲存在緩衝區中的張量與它們的處理後版本分開,我們不希望儲存處理後版本。對於非影像資料,考慮在CatFrames之前插入一個RenameTransform,以建立將在緩衝區中儲存的資料副本。注意
將 transform 新增到回放緩衝區時,應注意同時傳遞在 CatFrames 之前的 transform,例如
ToTensorImage或UnsqueezeTransform,以便CatFramestransform 看到的資料格式與資料收集時的格式相同。注意
有關更完整的示例,請參閱 torchrl 的 github 倉庫 examples 資料夾: https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]¶
轉換觀察規範,使結果規範與 transform 對映匹配。
- 引數:
observation_spec (TensorSpec) – transform 之前的規範
- 返回:
transform 之後的預期規範