快捷方式

TrajCounter

class torchrl.envs.transforms.TrajCounter(out_key: NestedKey = 'traj_count')[source]

全局轨迹计数器变换。

TrajCounter 可用于计算任何 TorchRL 环境中的轨迹数量(即调用 reset 的次数)。此变换可在单个节点内的多个进程中工作(见下注)。单个变换只能计算与单个完成状态相关的轨迹,但只要嵌套完成状态的前缀与计数器键的前缀匹配,就可以接受嵌套完成状态。

参数:

out_key (NestedKey, optional) – 轨迹计数器的条目名称。默认为 "traj_count"

示例

>>> from torchrl.envs import GymEnv, StepCounter, TrajCounter
>>> env = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
>>> env = env.append_transform(TrajCounter())
>>> r = env.rollout(18, break_when_any_done=False)  # 18 // 6 = 3 trajectories
>>> r["next", "traj_count"]
tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2]])

注意

在 workers 之间共享轨迹计数器可以通过多种方式实现,但这通常涉及将环境封装在 EnvCreator 中。不这样做可能会在变换序列化期间导致错误。计数器将在 workers 之间共享,这意味着在任何时间点,都可以保证不会有两个环境共享相同的轨迹计数(并且每个 (步数, 轨迹数) 对都将是唯一的)。以下是跨进程共享 TrajCounter 对象的有效方法示例

>>> # Option 1: Create the trajectory counter outside the environment.
>>> #  This requires the counter to be cloned within the transformed env, as a single transform object cannot have two parents.
>>> t = TrajCounter()
>>> def make_env(max_steps=4, t=t):
...     # See CountingEnv in torchrl.test.mocking_classes
...     env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone())
...     env.transform.transform_observation_spec(env.base_env.observation_spec)
...     return env
>>> penv = ParallelEnv(
...     2,
...     [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)],
...     mp_start_method="spawn",
... )
>>> # Option 2: Create the transform within the constructor.
>>> #  In this scenario, we still need to tell each sub-env what kwarg has to be used.
>>> #  Both EnvCreator and ParallelEnv offer that possibility.
>>> def make_env(max_steps=4):
...     t = TrajCounter()
...     env = TransformedEnv(CountingEnv(max_steps=max_steps), t)
...     env.transform.transform_observation_spec(env.base_env.observation_spec)
...     return env
>>> make_env_c0 = EnvCreator(make_env)
>>> # Create a variant of the env with different kwargs
>>> make_env_c1 = make_env_c0.make_variant(max_steps=5)
>>> penv = ParallelEnv(
...     2,
...     [make_env_c0, make_env_c1],
...     mp_start_method="spawn",
... )
>>> # Alternatively, pass the kwargs to the ParallelEnv
>>> penv = ParallelEnv(
...     2,
...     [make_env_c0, make_env_c0],
...     create_env_kwargs=[{"max_steps": 5}, {"max_steps": 4}],
...     mp_start_method="spawn",
... )
forward(tensordict: TensorDictBase) TensorDictBase[source]

读取输入的 tensordict,并对选定的键应用变换。

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)[source]

state_dict 中的参数和缓冲区复制到此模块及其后代中。

如果 strictTrue,则 state_dict 的键必须与此模块的 state_dict() 函数返回的键完全匹配。

警告

如果 assignTrue,则除非 get_swap_module_params_on_conversion()True,否则必须在调用 load_state_dict 后创建优化器。

参数:
  • state_dict (dict) – 包含参数和持久性缓冲区的字典。

  • strict (bool, optional) – 是否严格执行 state_dict 中的键与此模块的 state_dict() 函数返回的键匹配。默认为 True

  • assign (bool, optional) – 当设置为 False 时,保留当前模块中张量的属性,而设置为 True 时保留 state dict 中张量的属性。唯一的例外是 requires_grad 字段。默认为 False

返回:

  • missing_keys 是一个字符串列表,包含预期但

    提供的 state_dict 中缺失的键。

  • unexpected_keys 是一个字符串列表,包含此模块

    未预期但在提供的 state_dict 中存在的键。

返回类型:

包含 missing_keysunexpected_keys 字段的 NamedTuple

注意

如果参数或缓冲区注册为 None 且其对应的键存在于 state_dict 中,load_state_dict() 将引发 RuntimeError

state_dict(*args, destination=None, prefix='', keep_vars=False)[source]

返回一个包含模块完整状态引用的字典。

包括参数和持久性缓冲区(例如,运行平均值)。键是对应的参数和缓冲区名称。设置为 None 的参数和缓冲区不包括在内。

注意

返回的对象是浅拷贝。它包含对模块参数和缓冲区的引用。

警告

目前 state_dict() 也按顺序接受 destination、prefix 和 keep_vars 的位置参数。但是,此用法已被弃用,未来版本将强制使用关键字参数。

警告

请避免使用参数 destination,因为它不面向终端用户设计。

参数:
  • destination (dict, optional) – 如果提供,模块状态将更新到该字典中并返回同一对象。否则,将创建并返回一个 OrderedDict。默认为 None

  • prefix (str, optional) – 添加到参数和缓冲区名称前的前缀,用于在 state_dict 中构成键。默认为 ''

  • keep_vars (bool, optional) – 默认情况下,state dict 中返回的 Tensor 与 autograd 分离。如果设置为 True,则不会进行分离。默认为 False

返回:

包含模块完整状态的字典

返回类型:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
transform_observation_spec(observation_spec: Composite) Composite[source]

变换观测规范,使结果规范与变换映射匹配。

参数:

observation_spec (TensorSpec) – 变换前的规范

返回:

变换后预期的规范

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源