DPOLoss¶
- class torchtune.rlhf.loss.DPOLoss(beta: float = 0.1, label_smoothing: float = 0.0)[source]¶
直接偏好优化 (DPO) 损失模块:https://arxiv.org/abs/2305.18290 简单地说,从论文中来看
直观地说,DPO 更新会增加首选响应相对于非首选响应的相对对数概率,但它会合并一个动态的、特定于示例的重要性权重,从而阻止模型退化,我们发现这种情况在使用简单的概率比目标时会发生。
基于 HF 的 TRL 库中的实现:https://github.com/huggingface/trl/blob/5d1deb1445828cfd0e947cb3a7925b1c03a283fc/trl/trainer/dpo_trainer.py#L844
DPO 保留了与 PPO (https://arxiv.org/abs/2009.01325) 的相似性,它优化了策略 (语言) 模型以与人类偏好保持一致,并使用基线参考 (冻结的初始语言模型) 正则化损失函数,以防止过度拟合偏好数据集。它与 PPO 的区别在于,它直接使用标记的偏好数据优化策略模型,而不是使用额外的奖励模型来提供反馈。这大大简化了训练并降低了计算开销。
- 参数:
- forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) Tuple[Tensor, Tensor, Tensor] [source]¶
计算一批策略模型和参考模型对数概率的 DPO 损失。
- 参数:
policy_chosen_logps (torch.Tensor) – 策略模型对所选响应的对数概率。形状: (batch_size)
policy_rejected_logps (torch.Tensor) – 策略模型对被拒绝响应的对数概率。形状: (batch_size)
reference_chosen_logps (torch.Tensor) – 参考模型对所选响应的对数概率。形状: (batch_size)
reference_rejected_logps (torch.Tensor) – 参考模型对被拒绝响应的对数概率。形状: (batch_size)
- 返回值:
- 包含三个张量的元组
losses: 批次中每个示例的 DPO 损失。
chosen_rewards: 所选响应的奖励。
rejected_rewards: 被拒绝响应的奖励。
- 返回类型:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]