TensorDictMaxValueWriter
- class torchrl.data.replay_buffers.TensorDictMaxValueWriter(rank_key=None, reduction: str = 'sum', **kwargs)
A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
- Parameters:
  - rank_key (str or tuple of str) – The key used to rank the elements. Defaults to ("next", "reward").
  - reduction (str) – The reduction method to use if the rank key has more than one element. Can be "max", "min", "mean", "median" or "sum".
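For instance, a writer that ranks elements by the default nested reward key and keeps only the largest of several reward values per element could be constructed as in the following sketch (an illustrative construction based on the signature above, not part of the original examples):

>>> from torchrl.data import TensorDictMaxValueWriter
>>> # Rank stored elements by the nested ("next", "reward") entry; if an element
>>> # carries several reward values, only the largest one is used for ranking.
>>> writer = TensorDictMaxValueWriter(rank_key=("next", "reward"), reduction="max")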
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
...     "key": torch.tensor(range(10, 20)),
...     "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
Note
This class is not compatible with storages that have more than one dimension. This does not mean that storing trajectories is forbidden, but rather that trajectories must be stored on a per-trajectory basis. Here are a few examples of valid and invalid usages of the class. First, a flat buffer where we store individual transitions:
>>> from torchrl.data import TensorStorage
>>> # Simplest use case: data comes in 1d and is stored as such
>>> data = TensorDict({
...     "obs": torch.zeros(10, 3),
...     "reward": torch.zeros(10, 1),
... }, batch_size=[10])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(data)
>>> # Samples 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
Second, a buffer where we store trajectories. The maximum signal is aggregated within each batch (e.g., the reward of each rollout is summed):
>>> from torchrl.envs import GymEnv, ParallelEnv
>>> # One can also store batches of data, each batch being a sub-trajectory
>>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> # Get a batch of [2, 10] -- format is [Batch, Time]
>>> rollout = env.rollout(max_steps=10)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
>>> rb.extend(rollout)
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5, 10)
If the data comes in batches but a flat buffer is needed, we can simply flatten the data before extending the buffer:
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(rollout.reshape(-1))
>>> # Sample 5 transitions at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
It is not possible to create a buffer that is extended along the time dimension, which is usually the recommended way of using buffers with batches of trajectories. Since trajectories overlap, it is difficult (if not impossible) to aggregate the reward values and compare them. This constructor is not valid (note the ndim argument):
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
- add(data: Any) → int | torch.Tensor
Inserts a single element of data at an appropriate index, and returns that index.
The rank_key entry in the data passed to this module should have shape [] (a single scalar value). If it has more than one dimension, it will be reduced to a single value using the reduction method.
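As a minimal sketch of how this reduction plays out (illustrative only; the keys, shapes and values below are assumptions chosen for the example, not part of the original documentation), a buffer whose writer ranks whole trajectories by their summed reward could be used as follows:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=1),
...     writer=TensorDictMaxValueWriter(rank_key="reward", reduction="sum"),
... )
>>> # Each added element is one whole trajectory; its 5 reward values are summed for ranking
>>> low = TensorDict({"obs": torch.zeros(5, 3), "reward": torch.ones(5, 1) * 0.1}, batch_size=[])
>>> high = TensorDict({"obs": torch.ones(5, 3), "reward": torch.ones(5, 1)}, batch_size=[])
>>> idx = rb.add(low)
>>> idx = rb.add(high)
>>> # With max_size=1, only the trajectory with the larger summed reward should remain
>>> sample = rb.sample(1)
>>> assert sample["reward"].sum().item() == 5.0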