PrioritizedSliceSampler¶
- class torchrl.data.replay_buffers.PrioritizedSliceSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: dtype = torch.float32, reduction: str = 'max', *, num_slices: Optional[int] = None, slice_len: Optional[int] = None, end_key: Optional[NestedKey] = None, traj_key: Optional[NestedKey] = None, ends: Optional[Tensor] = None, trajectories: Optional[Tensor] = None, cache_values: bool = False, truncated_key: tensordict._nestedkey.NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: Union[bool, int, Tuple[bool | int, bool | int]] = False, max_priority_within_buffer: bool = False)[source]¶
使用优先级采样,沿第一维度采样数据切片,给定起始和停止信号。
- 此类按照 “Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.” 中提出的优先级权重替换采样子轨迹。
优先经验回放。” (https://arxiv.org/abs/1511.05952)
有关更多信息,请参阅
SliceSampler
和PrioritizedSampler
。警告
PrioritizedSliceSampler 将查看各个转换的优先级,并相应地采样起始点。这意味着如果优先级较低的转换跟随另一个优先级较高的转换,它们也可能会出现在样本中,而如果优先级较高但更接近轨迹末尾的转换不能用作起始点,则可能永远不会被采样。目前,用户有责任使用
update_priority()
聚合轨迹项目中各项的优先级。- 参数:
alpha (float) – 指数 α 决定了优先级的使用程度,α = 0 对应于均匀情况。
beta (float) – 重要性采样负指数。
eps (float, optional) – 添加到优先级的增量,以确保缓冲区不包含空优先级。默认为 1e-8。
reduction (str, optional) – 多维张量字典(即,存储的轨迹)的归约方法。可以是 “max”、“min”、“median” 或 “mean” 之一。
- 关键字参数:
num_slices (int) – 要采样的切片数量。批大小必须大于或等于
num_slices
参数。与slice_len
互斥。slice_len (int) – 要采样的切片长度。批大小必须大于或等于
slice_len
参数,并且可以被其整除。与num_slices
互斥。end_key (NestedKey, optional) – 指示轨迹(或 episode)结束的键。默认为
("next", "done")
。traj_key (NestedKey, optional) – 指示轨迹的键。默认为
"episode"
(TorchRL 中跨数据集常用)。ends (torch.Tensor, optional) – 包含运行结束信号的 1d 布尔张量。在
end_key
或traj_key
获取成本高昂时,或者当此信号容易获得时使用。必须与cache_values=True
一起使用,并且不能与end_key
或traj_key
结合使用。trajectories (torch.Tensor, optional) – 包含运行 id 的 1d 整数张量。在
end_key
或traj_key
获取成本高昂时,或者当此信号容易获得时使用。必须与cache_values=True
一起使用,并且不能与end_key
或traj_key
结合使用。cache_values (bool, optional) –
用于静态数据集。将缓存轨迹的起始和结束信号。即使轨迹索引在调用
extend
期间发生更改,也可以安全地使用此功能,因为此操作将擦除缓存。警告
如果采样器与由另一个缓冲区扩展的存储一起使用,
cache_values=True
将不起作用。例如>>> buffer0 = ReplayBuffer(storage=storage, ... sampler=SliceSampler(num_slices=8, cache_values=True), ... writer=ImmutableWriter()) >>> buffer1 = ReplayBuffer(storage=storage, ... sampler=other_sampler) >>> # Wrong! Does not erase the buffer from the sampler of buffer0 >>> buffer1.extend(data)
警告
如果缓冲区在进程之间共享,并且一个进程负责写入,另一个进程负责采样,则
cache_values=True
将无法按预期工作,因为擦除缓存只能在本地完成。truncated_key (NestedKey, optional) – 如果不是
None
,则此参数指示应将截断信号写入输出数据的位置。这用于向值估计器指示提供的轨迹在哪里中断。默认为("next", "truncated")
。此功能仅适用于TensorDictReplayBuffer
实例(否则,截断键在sample()
方法返回的 info 字典中返回)。strict_length (bool, optional) – 如果
False
,则允许长度小于 slice_len(或 batch_size // num_slices)的轨迹出现在批次中。如果True
,则将滤除短于要求的轨迹。请注意,这可能会导致有效 batch_size 小于要求的批大小!可以使用split_trajectories()
拆分轨迹。默认为True
。compile (bool or dict of kwargs, optional) – 如果
True
,则sample()
方法的瓶颈将使用compile()
进行编译。关键字参数也可以通过此参数传递给 torch.compile。默认为False
。span (bool, int, Tuple[bool | int, bool | int], optional) – 如果提供,则采样的轨迹将跨越左侧和/或右侧。这意味着提供的元素可能少于要求的元素。布尔值表示每个轨迹将采样至少一个元素。整数 i 表示每个采样的轨迹将收集至少 slice_len - i 个样本。使用元组可以精细控制左侧(存储轨迹的开头)和右侧(存储轨迹的结尾)的跨度。
max_priority_within_buffer (bool, optional) – 如果
True
,则在缓冲区内跟踪最大优先级。当False
时,最大优先级跟踪自采样器实例化以来的最大值。默认为False
。
示例
>>> import torch >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler >>> from tensordict import TensorDict >>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6) >>> data = TensorDict( ... { ... "observation": torch.randn(9,16), ... "action": torch.randn(9, 1), ... "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long), ... "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long), ... ("next", "observation"): torch.randn(9,16), ... ("next", "reward"): torch.randn(9,1), ... ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1), ... }, ... batch_size=[9], ... ) >>> rb.extend(data) >>> sample, info = rb.sample(return_info=True) >>> print("episode", sample["episode"].tolist()) episode [2, 2, 2, 2, 1, 1] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 1, 2] >>> print("weight", info["_weight"].tolist()) weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1]) >>> rb.update_priority(torch.arange(0,9,1), priority=priority) >>> sample, info = rb.sample(return_info=True) >>> print("episode", sample["episode"].tolist()) episode [2, 2, 2, 2, 2, 2] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 0, 1] >>> print("weight", info["_weight"].tolist()) weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
- update_priority(index: Union[int, Tensor], priority: Union[float, Tensor], *, storage: torchrl.data.replay_buffers.storages.TensorStorage | None = None) None ¶
更新索引指向的数据的优先级。
- 参数:
index (int or torch.Tensor) – 要更新的优先级的索引。
priority (Number or torch.Tensor) – 索引元素的新优先级。
- 关键字参数:
storage (Storage, optional) – 用于将 Nd 索引大小映射到 sum_tree 和 min_tree 的 1d 大小的存储。仅当
index.ndim > 2
时才需要。