快捷方式

LogValidationReward

class torchrl.trainers.LogValidationReward(*, record_interval: int, record_frames: int, frame_skip: int = 1, policy_exploration: TensorDictModule, environment: Optional[EnvBase] = None, exploration_type: InteractionType = InteractionType.RANDOM, log_keys: Optional[List[Union[str, Tuple[str]]]] = None, out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None, suffix: Optional[str] = None, log_pbar: bool = False, recorder: Optional[EnvBase] = None)[source]

Trainer 的记录器钩子。

参数:
  • record_interval (int) – 两次调用记录器进行测试之间的总优化步数。

  • record_frames (int) – 在测试期间要记录的帧数。

  • frame_skip (int) – 环境中使用的帧跳数。告知训练器每次迭代跳过的帧数非常重要,否则帧数统计可能会被低估。对于日志记录,此参数对于标准化奖励也很重要。最后,为了比较不同 frame_skip 的不同运行结果,必须对帧数和奖励进行标准化。默认为 1

  • policy_exploration (ProbabilisticTDModule) –

    一个用于

    1. 更新探索噪声调度的策略实例;

    2. 在记录器上测试策略。

    鉴于此实例既用于探索又用于呈现策略性能,因此应该可以通过调用 set_exploration_type(ExplorationType.DETERMINISTIC) 上下文管理器来关闭探索行为。

  • environment (EnvBase) – 用于测试的环境实例。

  • exploration_type (ExplorationType, optional) – 用于策略的探索模式。默认情况下,不使用探索,使用的值为 ExplorationType.DETERMINISTIC。设置为 ExplorationType.RANDOM 可启用探索

  • log_keys (sequence of str or tuples or str, optional) – 在 tensordict 中读取用于日志记录的键。默认为 [("next", "reward")]

  • out_keys (Dict[str, str], optional) – 一个字典,将 log_keys 映射到它们在日志中的名称。默认为 {("next", "reward"): "r_evaluation"}

  • suffix (str, optional) – 要录制视频的后缀。

  • log_pbar (bool, optional) – 如果为 True,则奖励值将记录在进度条上。默认为 False

register(trainer: Trainer, name: str = 'recorder')[source]

在默认位置将钩子注册到训练器中。

参数:
  • trainer (Trainer) – 必须注册钩子的训练器。

  • name (str) – 钩子的名称。

注意

要在非默认位置注册钩子,请使用 register_op()

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发者资源并获取问题解答

查看资源