快捷方式

RSOLoss

class torchtune.rlhf.loss.RSOLoss(gamma: float = 0.1)[源代码]

统计拒绝采样优化 (RSO) 或“铰链”损失模块:https://arxiv.org/abs/2309.06657。论文中的直觉

DPO 是对人类偏好数据进行的逻辑回归,而 SLiC (https://arxiv.org/abs/2305.10425) 几乎等同于具有铰链损失的支持向量机 (SVM)。[RSO] 改进[了] SLiC 作为 DPO 的 SVM 对应部分。

基于 HF 的 TRL 库中的实现:https://github.com/huggingface/trl/blob/4dce042a3863db1d375358e8c8092b874b02934b/trl/trainer/dpo_trainer.py#L1141

参数:

gamma (float) – RSO 损失的等效温度参数(来自 DPO)。

forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) Tuple[Tensor, Tensor, Tensor][源代码]

计算一批策略和参考模型对数概率的 RSO 损失。

参数:
  • policy_chosen_logps (torch.Tensor) – 策略模型对所选响应的对数概率。形状: (batch_size)

  • policy_rejected_logps (torch.Tensor) – 策略模型对被拒绝响应的对数概率。形状: (batch_size)

  • reference_chosen_logps (torch.Tensor) – 参考模型对所选响应的对数概率。形状: (batch_size)

  • reference_rejected_logps (torch.Tensor) – 参考模型对被拒绝响应的对数概率。形状: (batch_size)

返回值:

三个张量的元组
  • losses:批次中每个示例的 RSO 损失。

  • chosen_rewards:所选响应的奖励。

  • rejected_rewards:被拒绝响应的奖励。

返回类型:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源