PrioritizedSampler¶
- class torchrl.data.replay_buffers.PrioritizedSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', max_priority_within_buffer: bool = False)[source]¶
Prioritized sampler for a replay buffer.
Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
- Parameters:
max_capacity (int) – maximum capacity of the buffer.
alpha (float) – exponent α determines how much prioritization is used: α = 0 corresponds to the uniform case (see the sketch after this parameter list).
beta (float) – importance sampling negative exponent.
eps (float, optional) – delta added to the priorities to ensure that the buffer does not contain null priorities. Defaults to 1e-8.
reduction (str, optional) – the reduction method for multidimensional tensordicts (i.e., stored trajectories). Can be one of "max", "min", "median" or "mean".
max_priority_within_buffer (bool, optional) – if True, the max priority is tracked within the buffer. When False, the max priority tracks the maximum value since the instantiation of the sampler. Defaults to False.
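As a rough sketch of how the two exponents interact, following the paper's formulas rather than TorchRL's internal sum-tree implementation (all names below are illustrative):
>>> import torch
>>> priorities = torch.tensor([1.0, 2.0, 4.0])  # hypothetical per-sample priorities
>>> alpha, beta, eps = 0.6, 0.4, 1e-8
>>> probs = (priorities + eps) ** alpha  # P(i) ∝ p_i^alpha; alpha = 0 gives uniform sampling
>>> probs = probs / probs.sum()
>>> idx = torch.multinomial(probs, num_samples=2, replacement=True)
>>> weights = (len(priorities) * probs[idx]) ** -beta  # IS weights w_i = (N * P(i))^-beta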
Examples:
>>> import torch
>>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = ReplayBuffer(storage=LazyTensorStorage(10), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> rb.add(data_0)
>>> rb.add(data_1)
>>> rb.update_priority(torch.tensor([0, 1]), priority=priority)
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        obs: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        priority: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
        reward: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
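In a training loop, the "index" entry returned in info is typically fed back into update_priority() once fresh errors are available. A hedged sketch reusing the buffer above (the TD error here is a stand-in for a real loss term):
>>> sample, info = rb.sample(2, return_info=True)
>>> td_error = sample["reward"].float().abs() + 1.0  # placeholder for a real TD error
>>> rb.update_priority(torch.as_tensor(info["index"]), priority=td_error)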
Note
Using a TensorDictReplayBuffer can smoothen the process of updating the priorities:
>>> import torch
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
>>> from tensordict import TensorDict
>>> rb = TDRB(
...     storage=LazyTensorStorage(10),
...     sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
...     priority_key="priority",  # This kwarg isn't present in regular RBs
... )
>>> priority = torch.tensor([0, 1000])
>>> data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
>>> data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
>>> data = torch.stack([data_0, data_1])
>>> rb.extend(data)
>>> rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
>>> sample, info = rb.sample(10, return_info=True)
>>> print(sample['index'])  # The index is packed with the tensordict
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
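Because the sampled tensordict carries its own "index" entry and the priority_key passed to the constructor tells the buffer where to read priorities, update_priority(data) needs no explicit index or priority arguments, removing the bookkeeping that the plain ReplayBuffer example above requires.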
- update_priority(index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor], *, storage: TensorStorage | None = None) → None[source]¶
Update the priority of the data pointed to by the index.
- Parameters:
index (int or torch.Tensor) – indexes of the priorities to be updated.
priority (Number or torch.Tensor) – new priorities of the indexed elements.
- Keyword Arguments:
storage (Storage, optional) – the storage used to map the Nd index size to the 1d size of the sum_tree and min_tree. Only needed when index.ndim > 2 (see the sketch below).
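For the common case of 1d indices the storage keyword can be omitted. A minimal sketch, assuming a buffer rb built as in the examples above (ReplayBuffer exposes its sampler through the sampler property):
>>> rb.sampler.update_priority(torch.tensor([0, 1]), torch.tensor([10.0, 20.0]))
>>> # equivalent call through the buffer itself:
>>> rb.update_priority(torch.tensor([0, 1]), priority=torch.tensor([10.0, 20.0]))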