BurnInTransform
- class torchrl.envs.transforms.BurnInTransform(modules: Sequence[TensorDictModuleBase], burn_in: int, out_keys: Sequence[NestedKey] | None = None)[source]
Transform to partially burn in data sequences.
This transform is useful when up-to-date recurrent states are not available. It burns in a number of steps along the time dimension from sampled sequential data slices and returns the remaining data sequence, with the burnt-in state placed in its initial time step. This transform is intended to be used as a replay buffer transform, not as an environment transform.
- Parameters:
modules (sequence of TensorDictModule) – A list of modules used to burn in data sequences.
burn_in (int) – The number of time steps to burn in.
out_keys (sequence of NestedKey, optional) – Destination keys. Defaults to all the module out_keys that point to the next time step (e.g., ("next", "hidden") if "hidden" is among a module's out_keys).
Note
This transform expects input TensorDicts whose last dimension is the time dimension. It also assumes that all provided modules can process sequential data.
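Conceptually, the transform splits each sampled sequence at burn_in, runs the provided modules over the leading slice without tracking gradients, and writes the resulting recurrent state into the first time step of the remaining slice. The snippet below is a minimal sketch of that idea, assuming a single recurrent state stored under "hidden" with shape [*batch, time, num_layers, hidden_size]; burn_in_sketch is a hypothetical helper for illustration, not the library's actual implementation.

>>> import torch
>>> def burn_in_sketch(td, modules, burn_in):
...     # Split along the time dimension (assumed to be the last batch dim).
...     td_burn = td[..., :burn_in]   # steps used only to warm up the state
...     td_out = td[..., burn_in:]    # steps that are actually returned
...     with torch.no_grad():         # burn-in steps carry no gradient
...         for module in modules:
...             td_burn = module(td_burn)
...     # Seed the first remaining step with the freshly computed state,
...     # which the modules wrote under ("next", "hidden").
...     state = td_burn[..., -1].get(("next", "hidden"))
...     td_out.get("hidden")[..., 0, :, :] = state
...     return td_out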
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import BurnInTransform
>>> from torchrl.modules import GRUModule
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     in_keys=["observation", "hidden"],
...     out_keys=["intermediate", ("next", "hidden")],
...     default_recurrent_mode=True,
... )
>>> burn_in_transform = BurnInTransform(
...     modules=[gru_module],
...     burn_in=5,
... )
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...     "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...     "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> td = burn_in_transform(td)
>>> td.shape
torch.Size([2, 5])
>>> td.get("hidden").abs().sum()
tensor(86.3008)
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> buffer = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(2),
...     batch_size=1,
... )
>>> buffer.append_transform(burn_in_transform)
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...     "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...     "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> buffer.extend(td)
>>> td = buffer.sample(1)
>>> td.shape
torch.Size([1, 5])
>>> td.get("hidden").abs().sum()
tensor(37.0344)