RSOLoss¶
- class torchtune.rlhf.loss.RSOLoss(gamma: float = 0.1)[源代码]¶
统计拒绝采样优化 (RSO) 或“铰链”损失模块:https://arxiv.org/abs/2309.06657。论文中的直觉
DPO 是对人类偏好数据进行的逻辑回归,而 SLiC (https://arxiv.org/abs/2305.10425) 几乎等同于具有铰链损失的支持向量机 (SVM)。[RSO] 改进[了] SLiC 作为 DPO 的 SVM 对应部分。
基于 HF 的 TRL 库中的实现:https://github.com/huggingface/trl/blob/4dce042a3863db1d375358e8c8092b874b02934b/trl/trainer/dpo_trainer.py#L1141
- 参数:
gamma (float) – RSO 损失的等效温度参数(来自 DPO)。
- forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) Tuple[Tensor, Tensor, Tensor] [源代码]¶
计算一批策略和参考模型对数概率的 RSO 损失。
- 参数:
policy_chosen_logps (torch.Tensor) – 策略模型对所选响应的对数概率。形状: (batch_size)
policy_rejected_logps (torch.Tensor) – 策略模型对被拒绝响应的对数概率。形状: (batch_size)
reference_chosen_logps (torch.Tensor) – 参考模型对所选响应的对数概率。形状: (batch_size)
reference_rejected_logps (torch.Tensor) – 参考模型对被拒绝响应的对数概率。形状: (batch_size)
- 返回值:
- 三个张量的元组
losses:批次中每个示例的 RSO 损失。
chosen_rewards:所选响应的奖励。
rejected_rewards:被拒绝响应的奖励。
- 返回类型:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]