GRUModule¶
- class torchrl.modules.GRUModule(*args, **kwargs)[source]¶
An embedder for a GRU module.
This class adds the following functionality to torch.nn.GRU:
- Compatibility with TensorDict: the hidden states are reshaped to match the batch size of the tensordict.
- Optional multi-step execution: with torch.nn, one has to choose between torch.nn.GRUCell and torch.nn.GRU, the former being compatible with single-step inputs and the latter with multi-step inputs. This class enables both usages.
After construction, the module is not set in recurrent mode, i.e. it expects single-step inputs.
If in recurrent mode, the last dimension of the tensordict is expected to mark the number of steps. There is no constraint on the dimensionality of the tensordict (except that, for temporal inputs, it must be greater than one).
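To make the two input modes concrete, here is a minimal sketch (not part of the original documentation); the key names, the explicit "is_init" entry and the tensor shapes are illustrative assumptions and may need adjustment for a given setup.

# A sketch, assuming an observation feature size of 3 and arbitrary key names.
import torch
from tensordict import TensorDict
from torchrl.modules import GRUModule

gru_module = GRUModule(
    input_size=3,
    hidden_size=64,
    in_keys=["observation", "rs"],
    out_keys=["features", ("next", "rs")],
)

# Default (non-recurrent) mode: single-step input, no time dimension.
td_step = TensorDict(
    {"observation": torch.randn(8, 3), "is_init": torch.ones(8, 1, dtype=torch.bool)},
    batch_size=[8],
)
gru_module(td_step)  # writes "features" and ("next", "rs")

# Recurrent mode: the last tensordict dimension is interpreted as time.
td_traj = TensorDict(
    {"observation": torch.randn(8, 5, 3), "is_init": torch.zeros(8, 5, 1, dtype=torch.bool)},
    batch_size=[8, 5],
)
gru_module.set_recurrent_mode(True)(td_traj)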
- Parameters:
input_size – the number of expected features in the input x.
hidden_size – the number of features in the hidden state h.
num_layers – number of recurrent layers. E.g., setting num_layers=2 means stacking two GRUs together to form a stacked GRU, with the second GRU taking in the outputs of the first GRU and computing the final results. Default: 1
bias – if False, the layer does not use bias weights. Default: True
dropout – if non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0
python_based – if True, a full Python implementation of the GRU cell will be used. Default: False
- Keyword Arguments:
in_key (str or tuple of str) – the input key of the module. Exclusive with in_keys. If provided, the recurrent key is assumed to be ["recurrent_state"] and the in_key is prepended to it.
in_keys (list of str) – a pair of strings corresponding to the input value and the recurrent entry. Exclusive with in_key.
out_key (str or tuple of str) – the output key of the module. Exclusive with out_keys. If provided, the recurrent key is assumed to be [("recurrent_state")] and the out_key is prepended to it.
out_keys (list of str) – a pair of strings corresponding to the output value and the hidden key.
Note: for a better integration with TorchRL's environments, the best naming for the output hidden key is ("next", <custom_key>), such that the hidden values are passed from step to step during a rollout.
device (torch.device or compatible) – the device of the module.
gru (torch.nn.GRU, optional) – a GRU instance to be wrapped. Exclusive with the other nn.GRU arguments.
- Variables:
recurrent_mode – returns the recurrent mode of the module.
Note
This module relies on specific recurrent_state keys being present in the input TensorDict. To generate a TensorDictPrimer transform that will automatically add the hidden states to the environment TensorDicts, use the make_tensordict_primer() method. If this class is a submodule in a larger module, the get_primers_from_module() method can be called on the parent module to automatically generate the primer transforms required by all submodules, including this one (see the sketch after the example below).
Examples
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, GRUModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
>>> gru_module_training = gru_module.set_recurrent_mode()
>>> policy_training = Seq(gru_module_training, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3)  # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> print(traj_td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
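As mentioned in the note above, when the GRUModule is nested inside a larger module such as the policy built in this example, the primer transforms for every submodule can be gathered in a single call. The sketch below assumes the env and policy from the example above; the import path follows the note and may differ across TorchRL versions.

from torchrl.modules.utils import get_primers_from_module

primers = get_primers_from_module(policy)  # collects primers from all submodules, including the GRUModule
env = env.append_transform(primers)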
- forward(tensordict: TensorDictBase)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.
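A tiny illustration of that advice (gru_module and td are hypothetical placeholders for a GRUModule instance and an input tensordict):

out = gru_module(td)            # preferred: runs the registered hooks, then forward()
# out = gru_module.forward(td)  # works, but silently skips the hooks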
- make_tensordict_primer()[source]¶
Makes a tensordict primer for the environment.
A TensorDictPrimer object will ensure that the policy is aware of the supplementary inputs and outputs (the recurrent states) during rollout execution. That way, the data can be shared across processes and dealt with properly. Not including a TensorDictPrimer in the environment may result in poorly defined behavior, for instance in parallel settings where a step involves copying the new recurrent state from "next" to the root tensordict, which the step_mdp() method will not be able to do because the recurrent states are not registered within the environment specs.
Examples
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, GRUModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru_module = GRUModule(
...     input_size=env.observation_spec["observation"].shape[-1],
...     hidden_size=64,
...     in_keys=["observation", "rs"],
...     out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(gru_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=10
... )
>>> for data in data_collector:
...     print(data)
...     break
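As a hedged follow-up to the example above: once the primer transform has been appended, the reset tensordict should already carry a zero-initialized entry for the recurrent-state key, here "rs".

td0 = env.reset()
assert "rs" in td0.keys()  # zero-valued recurrent state provided by the primer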
- set_recurrent_mode(mode: bool = True)[source]¶
Returns a new copy of the module that shares the same gru model but with a different recurrent_mode attribute (if it differs).
A copy is created so that the module can be used with divergent behavior in various parts of the code (inference vs training, for instance).
Examples
>>> import torch
>>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp
>>> from torchrl.modules import MLP, GRUModule
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> # building two policies with different behaviours:
>>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3)  # some random temporal data
>>> traj_td = policy_training(traj_td)
>>> # let's check that both return the same results
>>> td_inf = TensorDict({}, traj_td.shape[:-1])
>>> for td in traj_td.unbind(-1):
...     td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation")))
...     td_inf = policy_inference(td_inf)
...     td_inf = step_mdp(td_inf)
...
>>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"])
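As a brief follow-up (a sketch continuing the example above): the copy returned by set_recurrent_mode carries the new mode, while the original module is left untouched.

gru_module_training = gru_module.set_recurrent_mode(True)
print(gru_module.recurrent_mode)           # False: the original module is unchanged
print(gru_module_training.recurrent_mode)  # True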