LSTMModule¶
- class torchrl.modules.LSTMModule(*args, **kwargs)[source]¶
一个用于 LSTM 模块的嵌入器。
该类为
torch.nn.LSTM
添加了以下功能:与 TensorDict 的兼容性:隐状态被重塑以匹配 tensordict 的批次大小。
可选的多步执行:使用 torch.nn 时,必须在
torch.nn.LSTMCell
和torch.nn.LSTM
之间选择,前者兼容单步输入,后者兼容多步输入。该类同时支持这两种用法。
构建后,模块 不 设置为循环模式,即它将预期单步输入。
如果在循环模式下,tensordict 的最后一个维度预计表示步数。tensordict 的维度没有限制(除了对于时间输入来说必须大于一)。
注意
该类可以处理沿时间维度的多个连续轨迹,但 在这些情况下,不应信任最终的隐状态值(即,它们不应被用于连续轨迹)。原因是 LSTM 只返回最后一个隐状态值,对于我们提供的填充输入,该值可能对应于一个填充零的输入。
- 参数:
input_size – 输入 x 中预期的特征数量
hidden_size – 隐状态 h 中的特征数量
num_layers – 循环层的数量。例如,设置
num_layers=2
意味着将两个 LSTM 堆叠在一起形成一个 堆叠 LSTM,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1bias – 如果为
False
,则该层不使用偏置权重 b_ih 和 b_hh。默认值:True
dropout – 如果非零,则在除最后一层以外的每个 LSTM 层的输出上引入一个 Dropout 层,dropout 概率等于
dropout
。默认值:0python_based – 如果为
True
,将使用 LSTM 单元的完整 Python 实现。默认值:False
- 关键字参数:
in_key (str 或 str 元组) – 模块的输入键。与
in_keys
互斥。如果提供,循环键假定为 [“recurrent_state_h”, “recurrent_state_c”],in_key
将被附加在它们之前。in_keys (str 列表) – 对应于输入值、第一个和第二个隐状态键的字符串三元组。与
in_key
互斥。out_key (str 或 str 元组) – 模块的输出键。与
out_keys
互斥。如果提供,循环键假定为 [(“next”, “recurrent_state_h”), (“next”, “recurrent_state_c”)],out_key
将被附加在它们之前。out_keys (str 列表) –
对应于输出值、第一个和第二个隐状态键的字符串三元组。.. 注意
For a better integration with TorchRL's environments, the best naming for the output hidden key is ``("next", <custom_key>)``, such that the hidden values are passed from step to step during a rollout.
device (torch.device 或 兼容类型) – 模块的设备。
lstm (torch.nn.LSTM, 可选) – 要封装的 LSTM 实例。与其他 nn.LSTM 参数互斥。
default_recurrent_mode (bool, 可选) – 如果提供,则指定循环模式,除非已被
set_recurrent_mode
上下文管理器/装饰器覆盖。默认值为False
。
- 变量:
recurrent_mode – 返回模块的循环模式。
注意
该模块依赖于输入 TensorDict 中存在特定的
recurrent_state
键。要生成一个TensorDictPrimer
transform,该 transform 将自动把隐状态添加到环境 TensorDict 中,请使用方法make_tensordict_primer()
。如果该类是一个更大模块的子模块,可以在父模块上调用方法get_primers_from_module()
,以自动生成包括本模块在内的所有子模块所需的 primer transforms。示例
>>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) TensorDict( fields={ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False), rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False)
- forward(tensordict: TensorDictBase = None)[source]¶
定义每次调用时执行的计算。
应被所有子类覆盖。
注意
尽管前向传播的实现需要在该函数中定义,但之后应调用
Module
实例而不是该函数,因为前者负责运行已注册的钩子,而后者会静默地忽略它们。
- make_cudnn_based() LSTMModule [source]¶
将 LSTM 层转换为基于 CuDNN 的版本。
- 返回值:
自身
- make_python_based() LSTMModule [source]¶
将 LSTM 层转换为基于 Python 的版本。
- 返回值:
自身
- make_tensordict_primer()[source]¶
为环境创建一个 tensordict primer。
一个
TensorDictPrimer
对象将确保策略在 rollout 执行期间感知补充输入和输出(循环状态)。这样,数据可以在不同进程之间共享并得到妥善处理。当使用批处理环境(例如
ParallelEnv
)时,transform 可以在单个环境实例级别使用(即,一批内部设置了 tensordict primer 的转换环境),也可以在批处理环境实例级别使用(即,一个转换过的常规环境批次)。不在环境中包含
TensorDictPrimer
可能导致行为定义不清,例如在并行设置中,一步涉及将新的循环状态从"next"
复制到根 tensordict,而 meth:~torchrl.EnvBase.step_mdp 方法将无法完成此操作,因为循环状态未在环境规范中注册。参阅
torchrl.modules.utils.get_primers_from_module()
,了解生成给定模块所有 primer 的方法。示例
>>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP, LSTMModule >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm_module = LSTMModule( ... input_size=env.observation_spec["observation"].shape[-1], ... hidden_size=64, ... in_keys=["observation", "rs_h", "rs_c"], ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy(env.reset()) >>> env = env.append_transform(lstm_module.make_tensordict_primer()) >>> data_collector = SyncDataCollector( ... env, ... policy, ... frames_per_batch=10 ... ) >>> for data in data_collector: ... print(data) ... break
- set_recurrent_mode(mode: bool = True)[source]¶
[已弃用 - 请改用
torchrl.modules.set_recurrent_mode
上下文管理器] 返回模块的新副本,该副本共享相同的 lstm 模型,但具有不同的recurrent_mode
属性(如果不同)。创建副本是为了使模块可以在代码的不同部分(推断 vs 训练)以不同的行为方式使用。
示例
>>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from tensordict import TensorDict >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) >>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> # building two policies with different behaviors: >>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> traj_td = env.rollout(3) # some random temporal data >>> traj_td = policy_training(traj_td) >>> # let's check that both return the same results >>> td_inf = TensorDict(batch_size=traj_td.shape[:-1]) >>> for td in traj_td.unbind(-1): ... td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation"))) ... td_inf = policy_inference(td_inf) ... td_inf = step_mdp(td_inf) ... >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"])