快捷方式

next_state_value

torchrl.objectives.next_state_value(tensordict: TensorDictBase, operator: Optional[TensorDictModule] = None, next_val_key: str = 'state_action_value', gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, **kwargs)[source]

计算下一个状态值(无梯度)以计算目标值。

目标值通常用于计算距离损失(例如 MSE)

L = Sum[ (q_value - target_value)^2 ]

目标值计算如下

r + gamma ** n_steps_to_next * value_next_state

如果奖励是即时奖励,则 n_steps_to_next=1。如果使用 N 步奖励,则从输入的 tensordict 中收集 n_steps_to_next。

参数:
  • tensordict (TensorDictBase) – 包含 reward 和 done 键(以及 n-steps 奖励的 n_steps_to_next 键)的 Tensordict。

  • operator (ProbabilisticTDModule, optional) – 值函数算子。调用时应在输入的 tensordict 中写入一个 ‘next_val_key’ 键值对。如果提供了 pred_next_val,则无需提供此参数。

  • next_val_key (str, optional) – 将写入下一个值的键。默认值: ‘state_action_value’

  • gamma (float, optional) – 回报折扣率。默认值: 0.99

  • pred_next_val (Tensor, optional) – 如果未使用算子计算,则可以提供下一个状态值。

返回值:

一个与输入的 tensordict 大小相同的 Tensor,包含预测的状态值。


© 版权所有 2022, Meta。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源