快捷方式

GRUModule

class torchrl.modules.GRUModule(*args, **kwargs)[source]

GRU 模块的嵌入器。

此类别为 torch.nn.GRU 添加了以下功能:

  • 与 TensorDict 的兼容性:隐藏状态会被重塑以匹配 tensordict 的批次大小。

  • 可选的多步执行:使用 torch.nn 时,必须在 torch.nn.GRUCelltorch.nn.GRU 之间进行选择,前者与单步输入兼容,后者与多步兼容。此类别同时支持这两种用法。

构建后,模块默认**不**处于循环模式,即它会期望单步输入。

如果在循环模式下,tensordict 的最后一个维度预期表示步数。tensordict 的维度没有限制(除了对于时间序列输入,维度必须大于一)。

参数:
  • input_size – 输入 x 中期望的特征数量

  • hidden_size – 隐藏状态 h 中的特征数量

  • num_layers – 循环层数量。例如,设置 num_layers=2 意味着将两个 GRU 堆叠在一起形成一个 堆叠 GRU,其中第二个 GRU 接收第一个 GRU 的输出并计算最终结果。默认值:1

  • bias – 如果为 False,则层不使用偏置权重。默认值:True

  • dropout – 如果非零,则在除最后一层外的每个 GRU 层的输出上引入一个 Dropout 层,丢弃概率等于 dropout。默认值:0

  • python_based – 如果为 True,将使用完整的 Python 实现的 GRU Cell。默认值:False

关键字参数:
  • in_key (strtuple of str) – 模块的输入键。与 in_keys 互斥使用。如果提供,循环键假定为 [“recurrent_state”],并且 in_key 将在此之前添加。

  • in_keys (list of str) – 对应于输入值和循环条目的字符串对。与 in_key 互斥。

  • out_key (strtuple of str) – 模块的输出键。与 out_keys 互斥使用。如果提供,循环键假定为 [(“recurrent_state”)],并且 out_key 将在这些键之前添加。

  • out_keys (list of str) –

    对应于输出值和隐藏状态的字符串对。

    For a better integration with TorchRL's environments, the best naming
    for the output hidden key is ``("next", <custom_key>)``, such
    that the hidden values are passed from step to step during a rollout.
    

  • device (torch.device兼容类型) – 模块的设备。

  • gru (torch.nn.GRU, 可选) – 要包装的 GRU 实例。与其他 nn.GRU 参数互斥。

  • default_recurrent_mode (bool, 可选) – 如果提供,则为未被 set_recurrent_mode 上下文管理器/装饰器覆盖时的循环模式。默认为 False

变量:

recurrent_mode – 返回模块的循环模式。

set_recurrent_mode()[source]

控制模块是否应在循环模式下执行。

make_tensordict_primer()[source]

创建 TensorDictPrimer 变换,使环境能够感知 RNN 的循环状态。

注意

此模块依赖于输入 TensorDict 中存在特定的 recurrent_state 键。要生成一个会自动向环境 TensorDict 添加隐藏状态的 TensorDictPrimer 变换,请使用方法 make_tensordict_primer()。如果此类别是较大模块的子模块,则可以在父模块上调用方法 get_primers_from_module(),以自动生成所有子模块(包括此模块)所需的 primer 变换。

示例

>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> gru_module_training = gru_module.set_recurrent_mode()
>>> policy_training = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> print(traj_td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
forward(tensordict: TensorDictBase = None)[source]

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

虽然前向传播的实现需要在该函数中定义,但之后应该调用 Module 实例而不是此函数本身,因为前者会处理已注册的钩子,而后者会静默忽略它们。

make_cudnn_based() GRUModule[source]

将 GRU 层转换为基于 CuDNN 的版本。

返回:

自身

make_python_based() GRUModule[source]

将 GRU 层转换为基于 python 的版本。

返回:

自身

make_tensordict_primer()[source]

为环境创建 tensordict primer。

一个 TensorDictPrimer 对象将确保策略在 rollout 执行期间感知到补充的输入和输出(循环状态)。这样,数据可以在进程之间共享并得到妥善处理。

如果在环境中不包含 TensorDictPrimer,可能会导致行为定义不清,例如在并行设置中,一步操作涉及将新的循环状态从 "next" 复制到根 tensordict,而 meth:~torchrl.EnvBase.step_mdp 方法将无法完成此操作,因为循环状态未在环境规范中注册。

使用批量环境(例如 ParallelEnv)时,变换可以在单个环境实例级别(即内部设置了 tensordict primer 的批量变换环境)或批量环境实例级别(即常规环境的变换批次)使用。

有关生成给定模块所有 primer 的方法,请参见 torchrl.modules.utils.get_primers_from_module()

示例

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(gru_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=10
... )
>>> for data in data_collector:
...     print(data)
...     break
set_recurrent_mode(mode: bool = True)[source]

[已弃用 - 请改用 torchrl.modules.set_recurrent_mode 上下文管理器] 返回模块的新副本,该副本共享相同的 gru 模型,但具有不同的 recurrent_mode 属性(如果不同)。

创建副本是为了使模块可以在代码的不同部分(推理与训练)中以不同的行为方式使用。

示例

>>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp
>>> from torchrl.modules import MLP
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> # building two policies with different behaviors:
>>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> # let's check that both return the same results
>>> td_inf = TensorDict(batch_size=traj_td.shape[:-1])
>>> for td in traj_td.unbind(-1):
...     td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation")))
...     td_inf = policy_inference(td_inf)
...     td_inf = step_mdp(td_inf)
...
>>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"])

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源