EnvCreator
- class torchrl.envs.EnvCreator(create_env_fn: Callable[[...], EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True)
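Environment creator class.
EnvCreator is a generic environment creator class that can replace lambda functions when environments are created in a multiprocessing context. If an environment created in a subprocess must share information with the main process (for example, for the VecNorm transform), EnvCreator passes pointers to the tensordicts held in shared memory to each process, so that all of them stay synchronized.
- Parameters:
create_env_fn (callable) – a callable that returns an EnvBase instance.
create_env_kwargs (dict, optional) – keyword arguments for the environment creator.
share_memory (bool, optional) – if False, the tensordict produced by the environment is not placed in shared memory.
Examples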
>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p2)
...     for p in ps:
...         p.join()
env 1: tensor([11.9934])
env 2: tensor([11.9934])
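As background to the class description above, the following standard-library sketch illustrates one common reason to avoid bare lambdas in multiprocessing code: the standard pickle module cannot serialize them, whereas a callable object that stores a module-level factory together with its keyword arguments usually can be. The names make_env, SimpleCreator and can_pickle are hypothetical and exist only for illustration; this is not how EnvCreator is implemented, and it does not cover the shared-memory synchronization that EnvCreator provides.

```python
import pickle


def make_env(name, seed=0):
    """Hypothetical stand-in for an environment constructor."""
    return {"name": name, "seed": seed}


class SimpleCreator:
    """Hypothetical, simplified creator; NOT torchrl's EnvCreator."""

    def __init__(self, create_fn, create_kwargs=None):
        self.create_fn = create_fn
        # Copy so later changes to the caller's dict do not leak in.
        self.create_kwargs = dict(create_kwargs or {})

    def __call__(self):
        # Construction is deferred until the creator is called,
        # which would typically happen inside the worker process.
        return self.create_fn(**self.create_kwargs)


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A lambda has no importable name, so pickle cannot serialize it.
env_fn = lambda: make_env("Pendulum-v1", seed=1)

# The creator stores a named function plus its arguments instead.
creator = SimpleCreator(make_env, {"name": "Pendulum-v1", "seed": 1})

print(can_pickle(env_fn))
print(creator())

if __name__ == "__main__":
    # Round-trip the creator through pickle, as a spawned worker would.
    restored = pickle.loads(pickle.dumps(creator))
    print(restored())
```

The design choice in this sketch is to keep the factory and its arguments as plain attributes, so the object that crosses the process boundary is small and the environment itself is only built on the receiving side.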