TensorDictMaxValueWriter¶
- class torchrl.data.replay_buffers.TensorDictMaxValueWriter(rank_key=None, reduction: str = 'sum', **kwargs)[source]¶
A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
- Parameters:
  - rank_key (str or tuple of str) – the key used to rank the elements. Defaults to ("next", "reward").
  - reduction (str) – the reduction method to use if the rank key has more than one element. Can be "max", "min", "mean", "median" or "sum".
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
9
>>> td = TensorDict({
...     "key": torch.tensor(range(10, 20)),
...     "obs": torch.tensor(range(10, 20))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10))
... }, batch_size=10)
>>> rb.extend(td)
>>> print(rb.sample().get("obs").item())
19
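The example above uses a storage of size 1, so only the single best element survives. The same mechanism generalizes to larger storages, where the buffer keeps the max_size highest-ranked elements seen so far. A minimal sketch (not from the original documentation, toy values assumed) keeping the top three transitions out of ten:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
>>> # Keep only the 3 best transitions according to "key"
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(3),
...     sampler=SamplerWithoutReplacement(),
...     batch_size=3,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> td = TensorDict({
...     "key": torch.tensor(range(10)),
...     "obs": torch.tensor(range(10)),
... }, batch_size=10)
>>> rb.extend(td)
>>> # Only the three transitions with the highest "key" values survive
>>> sorted(rb.sample().get("obs").tolist())
[7, 8, 9]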
Note
This class is not compatible with storages that have more than one dimension. This does not mean that storing trajectories is forbidden, but rather that the trajectories must be stored on a per-trajectory basis. Here are a few examples of valid and invalid usages of the class. First, a flat buffer where we store individual transitions:
>>> from torchrl.data import TensorStorage
>>> # Simplest use case: data comes in 1d and is stored as such
>>> data = TensorDict({
...     "obs": torch.zeros(10, 3),
...     "reward": torch.zeros(10, 1),
... }, batch_size=[10])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(data)
>>> # Samples 5 *transitions* at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
Second, a buffer where we store trajectories. The maximum signal is aggregated within each batch (e.g., the reward of each rollout is summed):
>>> # One can also store batches of data, each batch being a sub-trajectory
>>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> # Get a batch of [2, 10] -- format is [Batch, Time]
>>> rollout = env.rollout(max_steps=10)
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
>>> rb.extend(rollout)
>>> # Sample 5 trajectories at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5, 10)
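The reduction argument decides how the per-step rank values of a trajectory are aggregated before trajectories are compared. Here is a hedged sketch with two toy trajectories (values are illustrative, not from the original documentation): with the default reduction="sum" the steady-reward trajectory wins, whereas reduction="max" would favor the sparse spike:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> # Two trajectories of 4 steps each: one sparse spike, one steady reward
>>> trajs = TensorDict({
...     "obs": torch.zeros(2, 4, 3),
...     "reward": torch.tensor([[5.0, 0.0, 0.0, 0.0],
...                             [2.0, 2.0, 2.0, 2.0]]).unsqueeze(-1),
... }, batch_size=[2, 4])
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=1),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="reward", reduction="sum"),
... )
>>> rb.extend(trajs)
>>> # The kept trajectory is the one with the highest summed reward (8.0 > 5.0)
>>> rb.sample().get("reward").sum().item()
8.0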
If the data comes in batches but a flat buffer is needed, we can simply flatten the data before extending the buffer:
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100),
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
>>> # We initialize the buffer: a total of 100 *transitions* can be stored
>>> rb.extend(rollout.reshape(-1))
>>> # Sample 5 transitions at random
>>> sample = rb.sample(5)
>>> assert sample.shape == (5,)
Finally, it is not possible to create a buffer that is extended along the time dimension, which is otherwise the recommended way of using buffers with batches of trajectories. Since trajectories would overlap, aggregating the reward values and comparing them is hard, if not impossible. The following constructor is invalid (note the ndim argument):
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
...     writer=TensorDictMaxValueWriter(rank_key="reward")
... )
- add(data: Any) → int | torch.Tensor [source]¶
Inserts a single element of data at an appropriate index, and returns that index.
The rank_key in the data passed to this module should be structured as []. If it has more dimensions, it will be reduced to a single value using the reduction method.
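A small usage sketch of add() through the buffer front-end (hypothetical values, not from the original documentation): each element carries a scalar rank value, and elements ranking below the current minimum are simply discarded once the storage is full:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
>>> rb = TensorDictReplayBuffer(
...     storage=LazyTensorStorage(1),
...     batch_size=1,
...     writer=TensorDictMaxValueWriter(rank_key="key"),
... )
>>> # "key" has an empty shape ([]) for each element, so no reduction is needed
>>> _ = rb.add(TensorDict({"key": torch.tensor(1.0), "obs": torch.tensor(1.0)}, batch_size=[]))
>>> _ = rb.add(TensorDict({"key": torch.tensor(3.0), "obs": torch.tensor(3.0)}, batch_size=[]))
>>> _ = rb.add(TensorDict({"key": torch.tensor(2.0), "obs": torch.tensor(2.0)}, batch_size=[]))  # 2.0 < 3.0: dropped
>>> rb.sample().get("obs").item()
3.0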