快捷方式

优先级切片采样器

class torchrl.data.replay_buffers.PrioritizedSliceSampler(max_capacity: int, alpha: float, beta: float, eps: float = 1e-08, dtype: torch.dtype = torch.float32, reduction: str = 'max', *, num_slices: int = None, slice_len: int = None, end_key: NestedKey | None = None, traj_key: NestedKey | None = None, ends: torch.Tensor | None = None, trajectories: torch.Tensor | None = None, cache_values: bool = False, truncated_key: NestedKey | None = ('next', 'truncated'), strict_length: bool = True, compile: bool | dict = False, span: bool | int | Tuple[bool | int, bool | int] = False, max_priority_within_buffer: bool = False)[source]

根据开始和停止信号,使用优先级采样沿第一个维度对数据切片进行采样。

此类根据“Schaul,T.;Quan,J.;Antonoglou,I.;以及 Silver,D. 2015.”中提出的优先级权重进行有放回的子轨迹采样。

优先经验回放。” (https://arxiv.org/abs/1511.05952)

有关更多信息,请参阅 SliceSamplerPrioritizedSampler

警告

PrioritizedSliceSampler 将查看单个转换的优先级并相应地采样起始点。这意味着如果低优先级的转换紧随高优先级的转换,则它们也可能出现在样本中,而如果高优先级的转换靠近轨迹的末尾并且不能用作起始点,则可能永远不会被采样。目前,用户有责任使用 update_priority() 聚合轨迹中各个项目的优先级。

参数:
  • alpha (float) – 指数 α 决定使用多少优先级,其中 α = 0 对应于均匀情况。

  • beta (float) – 重要性采样负指数。

  • eps (float, 可选) – 添加到优先级的增量,以确保缓冲区不包含空优先级。默认为 1e-8。

  • reduction (str, 可选) – 多维张量字典(即存储的轨迹)的归约方法。可以是“max”、“min”、“median”或“mean”之一。

关键字参数:
  • num_slices (int) – 要采样的切片数量。批大小必须大于或等于num_slices参数。与slice_len互斥。

  • slice_len (int) – 要采样的切片的长度。批大小必须大于或等于slice_len参数,并且可以被其整除。与num_slices互斥。

  • end_key (NestedKey, 可选) – 指示轨迹(或情节)结束的键。默认为("next", "done")

  • traj_key (NestedKey, 可选) – 指示轨迹的键。默认为"episode"(在 TorchRL 中的数据集中普遍使用)。

  • ends (torch.Tensor, 可选) – 一个 1d 布尔张量,包含运行结束信号。在end_keytraj_key获取成本很高,或者此信号随时可用时使用。必须与cache_values=True一起使用,并且不能与end_keytraj_key结合使用。

  • trajectories (torch.Tensor, 可选) – 一个 1d 整数张量,包含运行 ID。在end_keytraj_key获取成本很高,或者此信号随时可用时使用。必须与cache_values=True一起使用,并且不能与end_keytraj_key结合使用。

  • cache_values (bool, 可选) –

    用于静态数据集。将缓存轨迹的开始和结束信号。即使轨迹索引在调用extend期间发生变化,也可以安全地使用此操作,因为此操作将清除缓存。

    警告

    cache_values=True如果采样器与由另一个缓冲区扩展的存储一起使用,则不起作用。例如

    >>> buffer0 = ReplayBuffer(storage=storage,
    ...     sampler=SliceSampler(num_slices=8, cache_values=True),
    ...     writer=ImmutableWriter())
    >>> buffer1 = ReplayBuffer(storage=storage,
    ...     sampler=other_sampler)
    >>> # Wrong! Does not erase the buffer from the sampler of buffer0
    >>> buffer1.extend(data)
    

    警告

    cache_values=True如果缓冲区在进程之间共享,并且一个进程负责写入,一个进程负责采样,则不会按预期工作,因为清除缓存只能在本地完成。

  • truncated_key (NestedKey, 可选) – 如果不为None,则此参数指示在输出数据中应写入截断信号的位置。这用于指示值估计器提供的轨迹在哪里断开。默认为("next", "truncated")。此功能仅适用于TensorDictReplayBuffer实例(否则,截断键将在sample()方法返回的信息字典中返回)。

  • strict_length (bool, 可选) – 如果为False,则允许长度短于slice_len(或batch_size // num_slices)的轨迹出现在批处理中。如果为True,则将过滤掉短于所需长度的轨迹。请注意,这可能导致有效的batch_size短于请求的批大小!可以使用split_trajectories()拆分轨迹。默认为True

  • compile (boolkwargs 字典, 可选) – 如果为True,则sample()方法的瓶颈将使用compile()进行编译。还可以使用此参数将关键字参数传递给 torch.compile。默认为False

  • span (bool, int, Tuple[bool | int, bool | int], 可选) – 如果提供,则采样的轨迹将跨越左侧和/或右侧。这意味着可能提供的元素少于所需元素。布尔值表示每个轨迹至少会采样一个元素。整数i表示每个采样轨迹至少会收集slice_len - i个样本。使用元组可以对存储轨迹的开头(左侧)和结尾(右侧)的跨度进行细粒度控制。

  • max_priority_within_buffer (bool, 可选) – 如果为True,则在缓冲区内跟踪最大优先级。当False时,最大优先级跟踪自采样器实例化以来的最大值。默认为False

示例

>>> import torch
>>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler
>>> from tensordict import TensorDict
>>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
>>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6)
>>> data = TensorDict(
...     {
...         "observation": torch.randn(9,16),
...         "action": torch.randn(9, 1),
...         "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long),
...         "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long),
...         ("next", "observation"): torch.randn(9,16),
...         ("next", "reward"): torch.randn(9,1),
...         ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1),
...     },
...     batch_size=[9],
... )
>>> rb.extend(data)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 1, 1]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 1, 2]
>>> print("weight", info["_weight"].tolist())
weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
>>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
>>> rb.update_priority(torch.arange(0,9,1), priority=priority)
>>> sample, info = rb.sample(return_info=True)
>>> print("episode", sample["episode"].tolist())
episode [2, 2, 2, 2, 2, 2]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 0, 1]
>>> print("weight", info["_weight"].tolist())
weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
update_priority(index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor], *, storage: TensorStorage | None = None) None

更新索引指向的数据的优先级。

参数:
  • index (inttorch.Tensor) – 要更新的优先级的索引。

  • priority (Numbertorch.Tensor) – 已索引元素的新优先级。

关键字参数:

storage (Storage, 可选) – 用于将 Nd 索引大小映射到 sum_tree 和 min_tree 的 1d 大小的存储。仅在index.ndim > 2时需要。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源