ReplayBufferTrainer¶
- class torchrl.trainers.ReplayBufferTrainer(replay_buffer: TensorDictReplayBuffer, batch_size: Optional[int] = None, memmap: bool = False, device: DEVICE_TYPING | None = None, flatten_tensordicts: bool = False, max_dims: Optional[Sequence[int]] = None)[source]¶
回放缓冲区钩子提供者。
- 参数:
replay_buffer (TensorDictReplayBuffer) – 要使用的回放缓冲区。
batch_size (int, optional) – 从最新采集或回放缓冲区中采样数据时的批次大小。如果未提供,将使用回放缓冲区的批次大小(对于批次大小不变的情况,这是首选选项)。
memmap (bool, optional) – 如果为
True
,则创建一个 memmap tensordict。默认为False
。device (device, optional) – 必须放置样本的设备。默认为
None
。flatten_tensordicts (bool, optional) – 如果为
True
,tensordict 在传递给回放缓冲区之前将被展平(或等效地使用从收集器获得的有效掩码进行掩码)。否则,除了填充(见下文max_dims
参数)之外,不会执行其他变换。默认为False
。max_dims (int 序列, optional) – 如果
flatten_tensordicts
设置为 False,这将是一个列表,其长度等于提供的 tensordict 的批次大小,表示每个 tensordict 的最大尺寸。如果提供,此尺寸列表将用于填充 tensordict,并在将其传递给回放缓冲区之前使其形状匹配。如果没有最大值,应提供 -1。
示例
>>> rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N) >>> trainer.register_op("batch_process", rb_trainer.extend) >>> trainer.register_op("process_optim_batch", rb_trainer.sample) >>> trainer.register_op("post_loss", rb_trainer.update_priority)