快捷方式

EnvCreator

torchrl.envs.EnvCreator(create_env_fn: Callable[[...], EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True, **kwargs)[源码]

环境创建器类。

EnvCreator 是一个通用的环境创建器类,可在多进程环境下创建环境时替代 lambda 函数。如果在子进程中创建的环境需要与主进程共享信息(例如用于 VecNorm 变换),EnvCreator 会将共享内存中 tensordict 的指针传递给每个进程,以便它们保持同步。

参数:
  • create_env_fn (callable) – 一个可调用对象,返回一个 EnvBase 实例。

  • create_env_kwargs (dict, optional) – 环境创建器的 kwargs。

  • share_memory (bool, optional) – 如果为 False,环境生成的 tensordict 不会放在共享内存中。

  • **kwargs – 在构建环境时要传递的额外关键字参数。

示例

>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations match on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p1)
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])
make_variant(**kwargs) EnvCreator[源码]

创建 EnvCreator 的一个变体,指向相同的底层元数据,但在构建时使用不同的关键字参数。

这对于共享状态的变换 (transforms) 可能有用,例如 TrajCounter

示例

>>> from torchrl.envs import GymEnv
>>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
>>> env_creator_cartpole = env_creator_pendulum(env_name="CartPole-v1")

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源