快捷方式

AutoResetTransform

class torchrl.envs.transforms.AutoResetTransform(*, replace: Optional[bool] = None, fill_float='nan', fill_int=- 1, fill_bool=False)[source]

用于自动重置环境的转换。

此转换可以附加到任何自动重置环境,或者使用 env = SomeEnvClass(..., auto_reset=True) 自动附加。如果转换显式附加到环境,则必须使用 AutoResetEnv

自动重置环境必须具有以下属性(与此描述的差异应通过子类化此类来解决)

  • reset 函数可以在开始时(实例化后)调用一次,无论是否生效。reset 之后是否允许调用取决于环境本身。

  • 在 rollout 期间,任何 done 状态都将导致重置,并产生一个 observation,该 observation 不是当前 episode 的最后一个 observation,而是下一个 episode 的第一个 observation(此转换将提取并缓存此 observation,并使用一些任意值填充 obs)。

关键字参数:
  • replace (bool, optional) – 如果 False,则值会按原样放置在 "next" 条目中,即使它们无效。默认为 TrueFalse 值会覆盖任何后续的填充关键字参数。此参数也可以通过传递 auto_reset_replace 参数的构造方法传递:env = FooEnv(..., auto_reset=True, auto_reset_replace=False)

  • fill_float (floatstr, optional) – 终止 episode 的浮点张量的填充值。None 值表示不替换(值会按原样放置在 "next" 条目中,即使它们无效)。

  • fill_int (int, optional) – 终止 episode 的有符号整数张量的填充值。None 值表示不替换(值会按原样放置在 "next" 条目中,即使它们无效)。

  • fill_bool (bool, optional) – 终止 episode 的布尔张量的填充值。None 值表示不替换(值会按原样放置在 "next" 条目中,即使它们无效)。

参数仅在显式实例化转换时可用(而不是通过 EnvType(…, auto_reset=True))。

示例

>>> from torchrl.envs import GymEnv
>>> from torchrl.envs import set_gym_backend
>>> import torch
>>> torch.manual_seed(0)
>>>
>>> class AutoResettingGymEnv(GymEnv):
...     def _step(self, tensordict):
...         tensordict = super()._step(tensordict)
...         if tensordict["done"].any():
...             td_reset = super().reset()
...             tensordict.update(td_reset.exclude(*self.done_keys))
...         return tensordict
...
...     def _reset(self, tensordict=None):
...         if tensordict is not None and "_reset" in tensordict:
...             return tensordict.copy()
...         return super()._reset(tensordict)
>>>
>>> with set_gym_backend("gym"):
...     env = AutoResettingGymEnv("CartPole-v1", auto_reset=True, auto_reset_replace=True)
...     env.set_seed(0)
...     r = env.rollout(30, break_when_any_done=False)
>>> print(r["next", "done"].squeeze())
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False, False])
>>> print("observation after reset are set as nan", r["next", "observation"])
observation after reset are set as nan tensor([[-4.3633e-02, -1.4877e-01,  1.2849e-02,  2.7584e-01],
        [-4.6609e-02,  4.6166e-02,  1.8366e-02, -1.2761e-02],
        [-4.5685e-02,  2.4102e-01,  1.8111e-02, -2.9959e-01],
        [-4.0865e-02,  4.5644e-02,  1.2119e-02, -1.2542e-03],
        [-3.9952e-02,  2.4059e-01,  1.2094e-02, -2.9009e-01],
        [-3.5140e-02,  4.3554e-01,  6.2920e-03, -5.7893e-01],
        [-2.6429e-02,  6.3057e-01, -5.2867e-03, -8.6963e-01],
        [-1.3818e-02,  8.2576e-01, -2.2679e-02, -1.1640e+00],
        [ 2.6972e-03,  1.0212e+00, -4.5959e-02, -1.4637e+00],
        [ 2.3121e-02,  1.2168e+00, -7.5232e-02, -1.7704e+00],
        [ 4.7457e-02,  1.4127e+00, -1.1064e-01, -2.0854e+00],
        [ 7.5712e-02,  1.2189e+00, -1.5235e-01, -1.8289e+00],
        [ 1.0009e-01,  1.0257e+00, -1.8893e-01, -1.5872e+00],
        [        nan,         nan,         nan,         nan],
        [-3.9405e-02, -1.7766e-01, -1.0403e-02,  3.0626e-01],
        [-4.2959e-02, -3.7263e-01, -4.2775e-03,  5.9564e-01],
        [-5.0411e-02, -5.6769e-01,  7.6354e-03,  8.8698e-01],
        [-6.1765e-02, -7.6292e-01,  2.5375e-02,  1.1820e+00],
        [-7.7023e-02, -9.5836e-01,  4.9016e-02,  1.4826e+00],
        [-9.6191e-02, -7.6387e-01,  7.8667e-02,  1.2056e+00],
        [-1.1147e-01, -9.5991e-01,  1.0278e-01,  1.5219e+00],
        [-1.3067e-01, -7.6617e-01,  1.3322e-01,  1.2629e+00],
        [-1.4599e-01, -5.7298e-01,  1.5848e-01,  1.0148e+00],
        [-1.5745e-01, -7.6982e-01,  1.7877e-01,  1.3527e+00],
        [-1.7285e-01, -9.6668e-01,  2.0583e-01,  1.6956e+00],
        [        nan,         nan,         nan,         nan],
        [-4.3962e-02,  1.9845e-01, -4.5015e-02, -2.5903e-01],
        [-3.9993e-02,  3.9418e-01, -5.0196e-02, -5.6557e-01],
        [-3.2109e-02,  5.8997e-01, -6.1507e-02, -8.7363e-01],
        [-2.0310e-02,  3.9574e-01, -7.8980e-02, -6.0090e-01]])
forward(tensordict: TensorDictBase) TensorDictBase[source]

读取输入的 tensordict,并为选定的键应用转换。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源