
EnvCreator

class torchrl.envs.EnvCreator(create_env_fn: Callable[[...], EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True)

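Environment creator class.

EnvCreator is a generic environment creator class that can replace lambda functions when creating environments in a multiprocessing context. If the environments created on subprocesses must share information with the main process (for example, for the VecNorm transform), EnvCreator passes each process pointers to tensordicts placed in shared memory, so that all of them stay synchronized.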
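The shared-state idea described above can be illustrated without torchrl. The sketch below is hypothetical and is not torchrl's implementation: it uses Python's standard-library `multiprocessing.shared_memory` in place of shared-memory tensordicts, and `ToyEnv`, `ToyEnvCreator`, `increment` and `observation_count` are invented names. The creator allocates one shared block and hands every environment it builds the block's name (a pointer-like handle), so a step taken by one environment is visible to all of them. Both environments live in the same process here for brevity; a real worker process would attach to the same block by its name.

```python
import struct
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, Optional

_SLOT = "d"  # one float64 counter stored at offset 0 of the shared block


class ToyEnv:
    """Hypothetical environment whose running statistic lives in shared memory."""

    def __init__(self, shm_name: str, increment: float = 1.0) -> None:
        # Attach to an existing block by name: the data is not copied.
        self._shm = shared_memory.SharedMemory(name=shm_name)
        self.increment = increment

    def step(self) -> None:
        (count,) = struct.unpack_from(_SLOT, self._shm.buf, 0)
        struct.pack_into(_SLOT, self._shm.buf, 0, count + self.increment)

    def observation_count(self) -> float:
        return struct.unpack_from(_SLOT, self._shm.buf, 0)[0]

    def close(self) -> None:
        # Detach this env's mapping; the block itself stays alive.
        self._shm.close()


class ToyEnvCreator:
    """Hypothetical stand-in for EnvCreator: allocates shared state once and
    passes the name of that block to every environment it creates."""

    def __init__(self, create_env_fn: Callable[..., Any],
                 create_env_kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.create_env_fn = create_env_fn
        self.create_env_kwargs = dict(create_env_kwargs or {})
        self._shm = shared_memory.SharedMemory(create=True, size=struct.calcsize(_SLOT))
        struct.pack_into(_SLOT, self._shm.buf, 0, 0.0)  # counter starts at zero

    def __call__(self) -> Any:
        return self.create_env_fn(shm_name=self._shm.name, **self.create_env_kwargs)

    def close(self) -> None:
        # Release the creator's mapping and free the block; call after all envs are closed.
        self._shm.close()
        self._shm.unlink()


creator = ToyEnvCreator(ToyEnv)
env1, env2 = creator(), creator()
for _ in range(10):
    env1.step()  # only env1 takes steps
print(env1.observation_count(), env2.observation_count())  # 10.0 10.0
env1.close()
env2.close()
creator.close()
```

Because `env2` reads the same memory that `env1` writes, it reports the same count without having stepped, which mirrors the behaviour the VecNorm example below checks across two processes.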

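Parameters:
  • create_env_fn (callable) – a callable that returns an EnvBase instance.

  • create_env_kwargs (dict, optional) – keyword arguments passed to create_env_fn when the environment is created.

  • share_memory (bool, optional) – if False, the tensordicts produced by the environment are not placed in shared memory.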
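To make the roles of `create_env_fn` and `create_env_kwargs` concrete, here is a hypothetical minimal creator written with the standard library only. It assumes that the stored keyword arguments are forwarded to `create_env_fn` each time the creator is called; `CountingEnv`, `max_steps` and `KwargsEnvCreator` are invented names, and `share_memory` is deliberately not modelled.

```python
from typing import Any, Callable, Dict, Optional


class CountingEnv:
    """Hypothetical toy environment; `max_steps` stands in for any constructor kwarg."""

    def __init__(self, max_steps: int = 5) -> None:
        self.max_steps = max_steps


class KwargsEnvCreator:
    """Hypothetical sketch of how `create_env_fn` and `create_env_kwargs` combine."""

    def __init__(self, create_env_fn: Callable[..., Any],
                 create_env_kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.create_env_fn = create_env_fn
        # Keep a private copy so later edits to the caller's dict have no effect.
        self.create_env_kwargs = dict(create_env_kwargs or {})

    def __call__(self) -> Any:
        # Every call builds a fresh instance from the same stored kwargs.
        return self.create_env_fn(**self.create_env_kwargs)


creator = KwargsEnvCreator(CountingEnv, create_env_kwargs={"max_steps": 200})
env_a, env_b = creator(), creator()
print(env_a.max_steps, env_b.max_steps, env_a is env_b)  # 200 200 False
```

Storing the factory and its arguments, rather than a ready-made environment, is what lets a creator object be handed to several worker processes, each of which builds its own instance.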

Examples

>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p2)
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])
