快捷方式

RSOLoss

class torchtune.rlhf.loss.RSOLoss(gamma: float = 0.1)[源代码]

统计拒绝采样优化 (RSO) 或“合页”损失模块:https://arxiv.org/abs/2309.06657。论文中的直觉

DPO 是对人类偏好数据的逻辑回归,而 SLiC (https://arxiv.org/abs/2305.10425) 几乎等同于使用合页损失的支持向量机 (SVM)。[RSO] 作为 DPO 的 SVM 对应物,改进了 SLiC。

基于 HF TRL 库中的实现:https://github.com/huggingface/trl/blob/4dce042a3863db1d375358e8c8092b874b02934b/trl/trainer/dpo_trainer.py#L1141

参数:

gamma (float) – RSO 损失的等效温度参数(来自 DPO)。

forward(policy_chosen_logps: Tensor, policy_rejected_logps: Tensor, reference_chosen_logps: Tensor, reference_rejected_logps: Tensor) Tuple[Tensor, Tensor, Tensor][源代码]

计算一批策略和参考模型对数概率的 RSO 损失。

参数:
  • policy_chosen_logps (torch.Tensor) – 策略模型对选中响应的对数概率。形状:(batch_size)

  • policy_rejected_logps (torch.Tensor) – 策略模型对拒绝响应的对数概率。形状:(batch_size)

  • reference_chosen_logps (torch.Tensor) – 参考模型对选中响应的对数概率。形状:(batch_size)

  • reference_rejected_logps (torch.Tensor) – 参考模型对拒绝响应的对数概率。形状:(batch_size)

返回值:

包含三个张量的元组
  • losses:批次中每个样本的 RSO 损失。

  • chosen_rewards:选中响应的奖励。

  • rejected_rewards:拒绝响应的奖励。

返回类型:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源