estimate_advantages¶
- torchtune.rlhf.estimate_advantages(values: Tensor, rewards: Tensor, gamma: float, lmbda: float, masks: Optional[Tensor] = None) Tuple[Tensor, Tensor] [源代码]¶
使用广义优势估计 https://arxiv.org/pdf/1506.02438.pdf 估计 PPO 算法的优势和回报。
- 参数:
values (torch.Tensor) – 每个状态的预测值。形状:
(b, response_len)
rewards (torch.Tensor) – 在每个时间步获得的奖励。形状:
(b, response_len)
gamma (float) – 折扣因子。
lmbda (float) – GAE-Lambda 参数。
masks (可选[torch.Tensor]) – 布尔掩码张量,其中 True 表示
values
中的对应值应参与平均值计算。默认值为 None。
- 返回:
- 包含估计的优势和回报的元组。
advantages (torch.Tensor):估计的优势。形状:
(b, response_len)
returns (torch.Tensor):估计的回报。形状:
(b, response_len)
- 返回类型:
Tuple[torch.Tensor, torch.Tensor]
- 符号
b:批次大小
response_len:模型响应长度