快捷方式

truncate_sequence_at_first_stop_token

torchtune.rlhf.truncate_sequence_at_first_stop_token(sequences: Tensor, stop_tokens: Tensor, fill_value: int = 0) Tuple[Tensor, Tensor][源代码]

在第一个停止标记后截断一个或多个序列,并用 fill_value 进行填充。

参数:
  • sequences (torch.Tensor) – 形状为 [batch_size, sequence_length] 或 [sequence_length] 的张量。

  • stop_tokens (torch.Tensor) – 包含停止标记的张量。

  • fill_value (int) – 在第一个停止标记后用于填充序列的值,通常是 pad_id

返回值:

形状与 sequences 相同的两个张量组成的元组
  • padding_mask (torch.Tensor): 一个布尔张量,其中 True 表示标记已被截断。

  • sequences (torch.Tensor) 一个被截断和填充的序列张量。

返回类型:

Tuple[torch.Tensor, torch.Tensor]

示例

>>> stop_token_ids = torch.tensor([2, 869])
>>> fill_value = 0
>>> sequences = torch.tensor(
>>>     [
>>>         [869, 30, 869],
>>>         [2, 30, 869],
>>>         [869, 30, 2],
>>>         [50, 30, 869],
>>>         [13, 30, 2],
>>>         [13, 30, 5],
>>>         [13, 2, 20],
>>>         [13, 2, 2],
>>>         [2, 2, 2],
>>>     ]
>>> )
>>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token(
>>>     sequences, stop_token_ids, fill_value
>>> )
>>> eos_mask
>>> torch.tensor([
>>>         [False, True, True],
>>>         [False, True, True],
>>>         [False, True, True],
>>>         [False, False, False],
>>>         [False, False, False],
>>>         [False, False, False],
>>>         [False, False, True],
>>>         [False, False, True],
>>>         [False, True, True],
>>>     ]
>>> )
>>> truncated_sequences
>>> torch.tensor([
>>>         [869, 0, 0],
>>>         [2, 0, 0],
>>>         [869, 0, 0],
>>>         [50, 30, 869],
>>>         [13, 30, 2],
>>>         [13, 30, 5],
>>>         [13, 2, 0],
>>>         [13, 2, 0],
>>>         [2, 0, 0],
>>>     ]
>>> )

文档

访问 PyTorch 全面的开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源