快捷方式

RSOLoss

class torchtune.rlhf.loss.RSOLoss(gamma: float = 0.1)[source]

统计拒绝采样优化 (RSO) 或“hinge”损失模块: https://arxiv.org/abs/2309.06657. 来自论文的直觉

DPO 是人类偏好数据的逻辑回归,而 SLiC (https://arxiv.org/abs/2305.10425) 几乎等同于带有 hinge 损失的支持向量机 (SVM)。[RSO] 改进了 SLiC,使其成为 DPO 的 SVM 对应物。

基于 HF 的 TRL 库中的实现: https://github.com/huggingface/trl/blob/4dce042a3863db1d375358e8c8092b874b02934b/trl/trainer/dpo_trainer.py#L1141

参数:

gamma (float) – RSO 损失的等效温度参数(来自 DPO)。

forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) Tuple[Tensor, Tensor, Tensor][source]

计算一批策略和参考模型对数概率的 RSO 损失。

参数:
  • policy_chosen_logps (torch.Tensor) – 策略模型针对选定响应的对数概率。形状: (batch_size)

  • policy_rejected_logps (torch.Tensor) – 策略模型针对拒绝响应的对数概率。形状: (batch_size)

  • reference_chosen_logps (torch.Tensor) – 参考模型针对选定响应的对数概率。形状: (batch_size)

  • reference_rejected_logps (torch.Tensor) – 参考模型针对拒绝响应的对数概率。形状: (batch_size)

返回:

包含三个张量的元组
  • losses: 批次中每个示例的 RSO 损失。

  • chosen_rewards: 选定响应的奖励。

  • rejected_rewards: 拒绝响应的奖励。

返回类型:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源