快捷方式

padded_collate_dpo

torchtune.data.padded_collate_dpo(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = - 100) Tuple[Tensor, Tensor][源代码]

为直接偏好优化 (DPO) 填充一批序列。

此函数接收一批序列,其中每个序列都表示为一个字典,包含多个键值对。每个键对应于不同的序列组件,例如 input_ids 或 labels。

参数:
  • batch (List[Dict[str, List[int]]]) – 字典列表,其中每个字典表示一个具有多个组件的序列,“chosen_input_ids”、“chosen_labels”、“rejected_input_ids”和“rejected_labels”是必需的。

  • padding_idx (int) – 输入 id 的填充索引。默认为 0。

  • ignore_idx (int) – 标签的填充索引。默认为 -100。

返回值:

包含连接和填充的输入 id 和标签的元组。

返回类型:

Tuple[torch.Tensor, torch.Tensor]

示例

>>> batch = [
>>>    {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5],
>>>      'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]},
>>>    {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15],
>>>      'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]},
>>> ]
>>> padded_collate_dpo(batch)
>>> (tensor([[ 1,  2,  3],
>>>          [11, 12,  0],
>>>          [ 4,  5,  0],
>>>          [13, 14, 15]]),
>>>  tensor([[ 6,  7,  8],
>>>          [16, 17, -100],
>>>          [ 9, 10, -100],
>>>          [18, 19, 20]]))

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源