快捷方式

VideoRecorder

torchrl.record.VideoRecorder(logger: Logger, tag: str, in_keys: Optional[Sequence[NestedKey]] = None, skip: int | None = None, center_crop: Optional[int] = None, make_grid: bool | None = None, out_keys: Optional[Sequence[NestedKey]] = None, fps: int | None = None, **kwargs) None[source]

Video Recorder 变换。

将记录来自环境的一系列观测结果,并在需要时将其写入 Logger 对象。

参数:
  • logger (Logger) – 视频应写入的 Logger 实例。要将视频保存为 memmap 张量或 mp4 文件,请使用 CSVLogger 类。

  • tag (str) – 日志记录器中的视频标签。

  • in_keys (Sequence of NestedKey, 可选) – 用于生成视频的读取键。默认为 "pixels"

  • skip (int) – 输出视频的帧间隔。如果变换具有父环境,则默认为 2,否则为 1

  • center_crop (int, 可选) – 中心方形裁剪的值。

  • make_grid (bool, 可选) – 如果为 True,则假设提供形状为 [B x W x H x 3] 的张量,其中 B 是批量大小,将创建一个网格。如果变换具有父环境,则默认为 True,否则为 False

  • out_keys (sequence of NestedKey, 可选) – 目标键。如果未提供,则默认为 in_keys

  • fps (int, 可选) – 输出视频的每秒帧数 (Frames per second)。默认为日志记录器预定义的 fps,如果提供,则覆盖该值。

  • **kwargs (Dict[str, Any], 可选) – log_video() 的额外关键字参数。

示例

以下示例展示了如何在视频中保存一次 rollout。首先导入一些库

>>> from torchrl.record import VideoRecorder
>>> from torchrl.record.loggers.csv import CSVLogger
>>> from torchrl.envs import TransformedEnv, DMControlEnv

视频格式在日志记录器中选择。Wandb 和 tensorboard 会自行处理,CSV 接受各种视频格式。

>>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")

一些环境(例如,Atari 游戏)原生返回图像,有些则需要用户请求。查看 GymEnvDMControlEnv 以了解如何在这些上下文中渲染图像。

>>> base_env = DMControlEnv("cheetah", "run", from_pixels=True)
>>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video"))
>>> env.rollout(100)

所有 transforms 都有一个 dump 函数,大多数情况下是空操作 (no-op),除了 VideoRecorderCompose,后者会将 dumps 分发给其所有成员。

>>> env.transform.dump()

该变换也可以在数据集中使用,以保存收集到的视频。与环境情况不同,图像将以批量形式出现。参数 skip 将使您能够仅在特定间隔保存图像。

>>> from torchrl.data.datasets import OpenXExperienceReplay
>>> from torchrl.envs import Compose
>>> from torchrl.record import VideoRecorder, CSVLogger
>>> # Create a logger that saves videos as mp4 using 24 frames per sec
>>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24)
>>> # We use the VideoRecorder transform to save register the images coming from the batch.
>>> #  Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged.
>>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12)
>>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False)
>>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
...             download=True, strict_length=False,
...             transform=t)
>>> # Get a batch of data and visualize it
>>> for data in dataset:
...     t.dump()
...     break

您的视频可在 ./cheetah_videos/cheetah/videos/run_video_0.mp4 下找到!

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源