EnvCreator
- class torchrl.envs.EnvCreator(create_env_fn: Callable[[...], EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True, **kwargs)
Environment creator class.
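EnvCreator is a generic environment creator class that can replace lambda functions when environments are created in a multiprocess setting. If the environments created in the subprocesses need to share information with the main process (for example, for the VecNorm transform), EnvCreator passes each process pointers to the tensordicts held in shared memory, so that all processes stay in sync.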
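The shared-memory mechanism described above can be illustrated with a minimal, standard-library-only sketch. It is not TorchRL's implementation: `SharedCounterCreator` and `ToyEnv` are hypothetical names, and a single float stored in `multiprocessing.shared_memory` stands in for the shared-memory tensordict. The creator allocates the block once and gives every environment only the block's name, a picklable handle; attaching by that name maps the same memory, so an update made through one environment is visible through the other. Both environments live in one process here for brevity; a worker process attaching with the same name would see the same value.

```python
import struct
from multiprocessing import shared_memory


class SharedCounterCreator:
    """Hypothetical stand-in for EnvCreator (not TorchRL code).

    Owns one shared-memory block and hands its name to every env it builds.
    """

    def __init__(self) -> None:
        # Allocate the shared state once, in the parent process.
        self._shm = shared_memory.SharedMemory(create=True, size=8)
        struct.pack_into("d", self._shm.buf, 0, 0.0)
        self.shm_name = self._shm.name

    def __call__(self) -> "ToyEnv":
        # Only the name (a string) is passed on, never a copy of the data.
        return ToyEnv(self.shm_name)

    def close(self) -> None:
        self._shm.close()
        self._shm.unlink()  # free the block once every env has closed it


class ToyEnv:
    """Hypothetical environment whose statistics live in shared memory."""

    def __init__(self, shm_name: str) -> None:
        # Attaching by name maps the same memory, in this or another process.
        self._shm = shared_memory.SharedMemory(name=shm_name)

    def step(self) -> None:
        # Increment the shared observation counter.
        count = struct.unpack_from("d", self._shm.buf, 0)[0]
        struct.pack_into("d", self._shm.buf, 0, count + 1.0)

    def count(self) -> float:
        return struct.unpack_from("d", self._shm.buf, 0)[0]

    def close(self) -> None:
        self._shm.close()


if __name__ == "__main__":
    creator = SharedCounterCreator()
    env_a, env_b = creator(), creator()
    for _ in range(3):
        env_a.step()
    print(env_b.count())  # 3.0: env_b never stepped but sees env_a's updates
    env_a.close()
    env_b.close()
    creator.close()
```

The read-modify-write in `step` is not atomic; real concurrent writers would need a lock, which this sketch leaves out.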
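- Parameters:
create_env_fn (callable) – a callable that returns an EnvBase instance.
create_env_kwargs (dict, optional) – keyword arguments passed to create_env_fn when the environment is built.
share_memory (bool, optional) – if False, the tensordict generated by the environment is not placed in shared memory. Defaults to True.
**kwargs – additional keyword arguments passed to the environment at construction time.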
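A minimal sketch of how the constructor keyword arguments might be stored and used, under the assumption (not confirmed by this page) that the extra `**kwargs` are merged into `create_env_kwargs` and take precedence on conflicts. `KwargsFactory` and `make_toy_env` are hypothetical names; the point is that nothing is constructed until the factory is called, for example inside a worker process.

```python
from typing import Any, Callable, Dict, Optional


class KwargsFactory:
    """Hypothetical stand-in for EnvCreator's kwarg handling (not TorchRL code)."""

    def __init__(
        self,
        create_env_fn: Callable[..., Any],
        create_env_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.create_env_fn = create_env_fn
        # Copy so that the caller's dict is never mutated.
        self.create_env_kwargs = dict(create_env_kwargs or {})
        # Assumption: extra keyword arguments are merged in and win on conflicts.
        self.create_env_kwargs.update(kwargs)

    def __call__(self) -> Any:
        # Construction is deferred until the factory is actually called.
        return self.create_env_fn(**self.create_env_kwargs)


def make_toy_env(env_name: str, frame_skip: int = 1) -> Dict[str, Any]:
    """Hypothetical environment constructor that just returns its config."""
    return {"env_name": env_name, "frame_skip": frame_skip}


base_kwargs = {"env_name": "Pendulum-v1"}
factory = KwargsFactory(make_toy_env, base_kwargs, frame_skip=4)
print(factory())  # {'env_name': 'Pendulum-v1', 'frame_skip': 4}
print(base_kwargs)  # {'env_name': 'Pendulum-v1'} (left untouched)
```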
Examples
>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p2)
...     for p in ps:
...         p.join()
env 1:  tensor([11.9934])
env 2:  tensor([11.9934])
- make_variant(**kwargs) → EnvCreator
Creates a variant of the EnvCreator that points to the same underlying metadata but uses different keyword arguments at construction time.
This can be useful with transforms that share a state, such as TrajCounter.
Examples
>>> from torchrl.envs import EnvCreator, GymEnv
>>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
>>> env_creator_cartpole = env_creator_pendulum.make_variant(env_name="CartPole-v1")
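The idea behind make_variant can be sketched with a hypothetical `VariantCreator` (not TorchRL code): a variant copies the stored keyword arguments, overrides some of them, and keeps a reference to the same shared state object instead of copying it. `ToyTrajCounter` is a hypothetical stand-in loosely modelled on a trajectory counter such as TrajCounter; because both creators hold the same counter, environments built from either one draw trajectory ids from a single sequence.

```python
from typing import Any, Callable, Dict, Optional


class ToyTrajCounter:
    """Hypothetical shared state, loosely modelled on a trajectory counter."""

    def __init__(self) -> None:
        self.next_id = 0

    def new_trajectory(self) -> int:
        traj_id = self.next_id
        self.next_id += 1
        return traj_id


class VariantCreator:
    """Hypothetical creator illustrating the make_variant idea (not TorchRL code)."""

    def __init__(
        self,
        create_env_fn: Callable[..., Any],
        shared_state: Optional[ToyTrajCounter] = None,
        **kwargs: Any,
    ) -> None:
        self.create_env_fn = create_env_fn
        self.kwargs = kwargs
        # The shared state is allocated once and afterwards only referenced.
        self.shared_state = shared_state if shared_state is not None else ToyTrajCounter()

    def make_variant(self, **kwargs: Any) -> "VariantCreator":
        # New kwargs override the stored ones; the shared state object is reused.
        merged = {**self.kwargs, **kwargs}
        return VariantCreator(self.create_env_fn, self.shared_state, **merged)

    def __call__(self) -> Any:
        return self.create_env_fn(shared_state=self.shared_state, **self.kwargs)


def make_toy_env(env_name: str, shared_state: ToyTrajCounter) -> Dict[str, Any]:
    # Each new environment claims the next id from the shared counter.
    return {"env_name": env_name, "traj_id": shared_state.new_trajectory()}


pendulum = VariantCreator(make_toy_env, env_name="Pendulum-v1")
cartpole = pendulum.make_variant(env_name="CartPole-v1")
print(pendulum())  # {'env_name': 'Pendulum-v1', 'traj_id': 0}
print(cartpole())  # {'env_name': 'CartPole-v1', 'traj_id': 1}
```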