快捷方式

ReplayBufferTrainer

class torchrl.trainers.ReplayBufferTrainer(replay_buffer: TensorDictReplayBuffer, batch_size: Optional[int] = None, memmap: bool = False, device: Union[device, str, int] = 'cpu', flatten_tensordicts: bool = False, max_dims: Optional[Sequence[int]] = None)[source]

回放缓冲区钩子提供程序。

参数:
  • replay_buffer (TensorDictReplayBuffer) – 要使用的回放缓冲区。

  • batch_size (int, 可选) – 从最新集合或回放缓冲区采样数据时的批次大小。如果未提供,则将使用回放缓冲区的批次大小(建议选项,用于批次大小不变的情况)。

  • memmap (bool, 可选) – 如果True,则会创建一个内存映射张量字典。默认为 False。

  • device (设备, 可选) – 必须放置样本的设备。默认为 cpu。

  • flatten_tensordicts (bool, 可选) – 如果True,则在传递到回放缓冲区之前,将展平张量字典(或等效地使用从收集器获得的有效掩码进行掩码)。否则,除了填充之外,不会进行任何转换(请参阅下面的max_dims参数)。默认为 True

  • max_dims (整数序列, 可选) – 如果flatten_tensordicts设置为 False,则这将是提供的张量字典的批次大小长度的列表,表示每个张量字典的最大大小。如果提供,则此大小列表将用于填充张量字典并使其形状匹配,然后将其传递到回放缓冲区。如果没有最大值,则应提供 -1 值。

示例

>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N)
>>> trainer.register_op("batch_process", rb_trainer.extend)
>>> trainer.register_op("process_optim_batch", rb_trainer.sample)
>>> trainer.register_op("post_loss", rb_trainer.update_priority)
register(trainer: Trainer, name: str = 'replay_buffer')[source]

在训练器中的默认位置注册钩子。

参数:
  • trainer (Trainer) – 必须注册钩子的训练器。

  • name (str) – 钩子的名称。

注意

要在默认位置以外的位置注册钩子,请使用register_op()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源