快捷方式

BurnInTransform

class torchrl.envs.transforms.BurnInTransform(modules: Sequence[TensorDictModuleBase], burn_in: int, out_keys: Sequence[NestedKey] | None = None)[source]

用于部分预热数据序列的转换。

此转换对于在不可用时获取最新的循环状态很有用。它会在采样后的顺序数据切片的时间维度上预热一定数量的步骤,并返回剩余的数据序列,其中预热数据位于其初始时间步骤。此转换旨在用作回放缓冲区转换,而不是环境转换。

参数:
  • modules (TensorDictModule 的序列) – 用于预热数据序列的模块列表。

  • burn_in (int) – 要预热的步数。

  • out_keys (NestedKey 的序列, 可选) – 目标键。默认为

  • ` (所有指向下一个时间步骤的模块 out_keys (例如,如果) –

  • ("next"

  • module). ("hidden")` 是模块的 out_keys 的一部分) –

注意

此转换期望输入 TensorDicts,其最后一个维度是时间维度。它还假设所有提供的模块都可以处理顺序数据。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.transforms import BurnInTransform
>>> from torchrl.modules import GRUModule
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     in_keys=["observation", "hidden"],
...     out_keys=["intermediate", ("next", "hidden")],
... ).set_recurrent_mode(True)
>>> burn_in_transform = BurnInTransform(
...     modules=[gru_module],
...     burn_in=5,
... )
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> td = burn_in_transform(td)
>>> td.shape
torch.Size([2, 5])
>>> td.get("hidden").abs().sum()
tensor(86.3008)
>>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
>>> buffer = TensorDictReplayBuffer(
...     storage=LazyMemmapStorage(2),
...     batch_size=1,
... )
>>> buffer.append_transform(burn_in_transform)
>>> td = TensorDict({
...     "observation": torch.randn(2, 10, 10),
...      "hidden": torch.randn(2, 10, gru_module.gru.num_layers, 10),
...      "is_init": torch.zeros(2, 10, 1),
... }, batch_size=[2, 10])
>>> buffer.extend(td)
>>> td = buffer.sample(1)
>>> td.shape
torch.Size([1, 5])
>>> td.get("hidden").abs().sum()
tensor(37.0344)
forward(tensordict: TensorDictBase) TensorDictBase[source]

读取输入 tensordict,并针对选定的键应用转换。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获得针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源