PPOLoss¶
- class torchtune.rlhf.loss.PPOLoss(epsilon: float = 0.1, value_clip_range: float = 0.2, value_coeff: float = 0.1)[源代码]¶
近端策略优化 (PPO) 损失模块。此实现使用以下参考
https://arxiv.org/abs/1707.06347 公式 7
- 参数:
- forward(pi_old_logprobs: Tensor, pi_logprobs: Tensor, advantages: Tensor, phi_old_values: Tensor, phi_values: Tensor, returns: Tensor, padding_masks: Optional[Tensor] = None, value_padding_masks: Optional[Tensor] = None) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor] [源代码]¶
PPO 损失模块的前向传递。
- 参数:
pi_old_logprobs (torch.Tensor) – 旧策略的对数概率。
pi_logprobs (torch.Tensor) – 当前策略的对数概率。
advantages (torch.Tensor) – 优势值。
phi_old_values (torch.Tensor) – 旧值函数的值预测。
phi_values (torch.Tensor) – 当前值函数的值预测。
returns (torch.Tensor) – 返回值。
padding_masks (Optional[torch.Tensor]) – 与
pi_logprobs
形状相同的填充标记掩码,其中 True 表示相应的损失值应参与策略损失计算。value_padding_masks (Optional[torch.Tensor]) – 与
pi_logprobs
形状相同的填充标记掩码,其中 True 表示相应的损失值应参与值损失计算。
- 返回值:
- 包含五个张量的元组
loss: 总 PPO 损失。
policy_loss: 策略函数损失。
value_loss: 值函数损失。
ratios: 当前和旧策略概率之间的比率。
clipfrac: 被裁剪的比率的分数。
- 返回类型:
Tuple[torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor]