快捷方式

TensorDictMaxValueWriter

class torchrl.data.replay_buffers.TensorDictMaxValueWriter(rank_key=None, reduction: str = 'sum', **kwargs)[源代码]

一个用于可组合回放缓冲区的 Writer 类,它根据某些排名键保留最顶部的元素。

参数:
  • rank_key (strstr 元组) – 用于对元素进行排名的键。默认为 ("next", "reward")

  • reduction (str) – 如果排名键有多个元素,则使用的方法。可以是 "max""min""mean""median""sum"

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
...     "key": torch.tensor(range(10, 20)),
...     "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19

注意

此类与具有多个维度的存储不兼容。这并不意味着禁止存储轨迹,而是指存储的轨迹必须按轨迹为单位进行存储。以下是一些类有效和无效用法的示例。首先,一个扁平缓冲区,我们存储单个转换

>>> from torchrl.data import TensorStorage
>>> # Simplest use case: data comes in 1d and is stored as such
>>> data = TensorDict({
...     "obs": torch.zeros(10, 3),
...     "reward": torch.zeros(10, 1),
... }, batch_size=[10])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(data)
>>> # Samples 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)

其次,一个我们存储轨迹的缓冲区。最大信号在每个批次中聚合(例如,每个 rollout 的奖励被求和)

>>> # One can also store batches of data, each batch being a sub-trajectory
>>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> # Get a batch of [2, 10] -- format is [Batch, Time]
>>> rollout = env.rollout(max_steps=10)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
>>> rb.extend(rollout)
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5, 10)

如果数据以批次形式出现,但需要扁平缓冲区,我们可以在扩展缓冲区之前简单地展平数据

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(rollout.reshape(-1))
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)

无法创建沿着时间维度扩展的缓冲区,这通常是使用带有轨迹批次的缓冲区的推荐方法。由于轨迹是重叠的,因此难以(如果不是不可能的话)聚合奖励值并进行比较。此构造函数无效(请注意 ndim 参数)

>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
add(data: Any) int | torch.Tensor[源代码]

在适当的索引处插入单个数据元素,并返回该索引。

传递到此模块的数据中的 rank_key 应结构化为 []。如果它具有多个维度,它将使用 reduction 方法缩减为单个值。

extend(data: TensorDictBase) None[源代码]

在适当的索引处插入一系列数据点。

传递到此模块的数据中的 rank_key 应结构化为 [B]。如果它具有多个维度,它将使用 reduction 方法缩减为单个值。

get_insert_index(data: Any) int[源代码]

返回应插入数据的位置索引,如果数据不应插入,则返回 None

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源