快捷方式

VideoRecorder

torchrl.record.VideoRecorder(logger: Logger, tag: str, in_keys: Optional[Sequence[NestedKey]] = None, skip: int | None = None, center_crop: Optional[int] = None, make_grid: bool | None = None, out_keys: Optional[Sequence[NestedKey]] = None, fps: int | None = None, **kwargs) None[source]

Video Recorder 變換。

將記錄來自環境的一系列觀測結果,並在需要時將其寫入 Logger 物件。

引數:
  • logger (Logger) – 影片應寫入的 Logger 例項。要將影片儲存為 memmap 張量或 mp4 檔案,請使用 CSVLogger 類。

  • tag (str) – 日誌記錄器中的影片標籤。

  • in_keys (Sequence of NestedKey, 可選) – 用於生成影片的讀取鍵。預設為 "pixels"

  • skip (int) – 輸出影片的幀間隔。如果變換具有父環境,則預設為 2,否則為 1

  • center_crop (int, 可選) – 中心方形裁剪的值。

  • make_grid (bool, 可選) – 如果為 True,則假設提供形狀為 [B x W x H x 3] 的張量,其中 B 是批次大小,將建立一個網格。如果變換具有父環境,則預設為 True,否則為 False

  • out_keys (sequence of NestedKey, 可選) – 目標鍵。如果未提供,則預設為 in_keys

  • fps (int, 可選) – 輸出影片的每秒幀數 (Frames per second)。預設為日誌記錄器預定義的 fps,如果提供,則覆蓋該值。

  • **kwargs (Dict[str, Any], 可選) – log_video() 的額外關鍵字引數。

示例

以下示例展示瞭如何在影片中儲存一次 rollout。首先匯入一些庫

>>> from torchrl.record import VideoRecorder
>>> from torchrl.record.loggers.csv import CSVLogger
>>> from torchrl.envs import TransformedEnv, DMControlEnv

影片格式在日誌記錄器中選擇。Wandb 和 tensorboard 會自行處理,CSV 接受各種影片格式。

>>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")

一些環境(例如,Atari 遊戲)原生返回影像,有些則需要使用者請求。檢視 GymEnvDMControlEnv 以瞭解如何在這些上下文中渲染影像。

>>> base_env = DMControlEnv("cheetah", "run", from_pixels=True)
>>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video"))
>>> env.rollout(100)

所有 transforms 都有一個 dump 函式,大多數情況下是空操作 (no-op),除了 VideoRecorderCompose,後者會將 dumps 分發給其所有成員。

>>> env.transform.dump()

該變換也可以在資料集中使用,以儲存收集到的影片。與環境情況不同,影像將以批次形式出現。引數 skip 將使您能夠僅在特定間隔儲存影像。

>>> from torchrl.data.datasets import OpenXExperienceReplay
>>> from torchrl.envs import Compose
>>> from torchrl.record import VideoRecorder, CSVLogger
>>> # Create a logger that saves videos as mp4 using 24 frames per sec
>>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24)
>>> # We use the VideoRecorder transform to save register the images coming from the batch.
>>> #  Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged.
>>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12)
>>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False)
>>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
...             download=True, strict_length=False,
...             transform=t)
>>> # Get a batch of data and visualize it
>>> for data in dataset:
...     t.dump()
...     break

您的影片可在 ./cheetah_videos/cheetah/videos/run_video_0.mp4 下找到!

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源