快捷方式

PixelRenderTransform

torchrl.record.PixelRenderTransform(out_keys: List[NestedKey] = None, preproc: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor] = None, as_non_tensor: bool = None, render_method: str = 'render', pass_tensordict: bool = False, **kwargs) None[源代码]

在父环境上调用 render 并将像素观测结果注册到 tensordict 中的转换。

此转换提供了一种替代方法,用于在实例化提供渲染的环境时使用 from_pixels 语法糖,如果渲染很昂贵,或者如果未实现 from_pixels。它可以在单个环境或批处理环境中使用。

参数::
  • out_keys (List[NestedKey] 或 Nested) – 注册像素观测结果的键列表。

  • preproc (Callable, 可选) – 预处理函数。可用于重新整形观测结果,或应用任何其他转换,使其能够在输出数据中注册。

  • as_non_tensor (bool, 可选) – 如果 True,则数据将被写为 NonTensorData,从而放宽形状要求。如果未提供,它将从输入数据类型和形状自动推断。

  • render_method (str, 可选) – 渲染方法的名称。默认为 "render"

  • pass_tensordict (bool, 可选) – 如果 True,则输入 tensordict 将传递给渲染方法。这使得无状态环境的渲染成为可能。默认为 False

  • **kwargs – 传递给渲染函数的额外关键字参数(例如 mode="rgb_array")。

示例

>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> def make_env():
>>>     env = GymEnv("CartPole-v1", render_mode="rgb_array")
>>>     env = env.append_transform(PixelRenderTransform())
>>>     return env
>>>
>>> if __name__ == "__main__":
...     logger = CSVLogger("dummy", video_format="mp4")
...
...     env = ParallelEnv(4, EnvCreator(make_env))
...
...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
...     env.rollout(3)
...
...     check_env_specs(env)
...
...     r = env.rollout(30)
...     print(env)
...     env.transform.dump()
...     env.close()

此转换还可以在批处理环境 render() 返回单个图像时使用

示例

>>> from torchrl.envs import check_env_specs
>>> from torchrl.envs.libs.vmas import VmasEnv
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> env = VmasEnv(
...     scenario="flocking",
...     num_envs=32,
...     continuous_actions=True,
...     max_steps=200,
...     device="cpu",
...     seed=None,
...     # Scenario kwargs
...     n_agents=5,
... )
>>>
>>> logger = CSVLogger("dummy", video_format="mp4")
>>>
>>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy()))
>>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
>>>
>>> check_env_specs(env)
>>>
>>> r = env.rollout(30)
>>> env.transform[-1].dump()

可以使用 switch() 方法禁用此转换,该方法将打开或关闭渲染(也可以传递参数来控制此行为)。由于转换是 Module 实例,因此可以使用 apply() 来控制此行为

>>> def switch(module):
...     if isinstance(module, PixelRenderTransform):
...         module.switch()
>>> env.apply(switch)

文档

访问 PyTorch 的全面的开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源