快捷方式

truncate_sequence_at_first_stop_token

torchtune.rlhf.truncate_sequence_at_first_stop_token(sequences: Tensor, stop_tokens: Tensor, fill_value: int = 0) Tuple[Tensor, Tensor][source]

在第一个停止标记之后截断序列(s),并用 fill_value 填充。

参数::
  • sequences (torch.Tensor) – 形状为 [batch_size, sequence_length] 或 [sequence_length] 的张量。

  • stop_tokens (torch.Tensor) – 包含停止标记的张量。

  • fill_value (int) – 在第一个停止标记之后用它来填充序列的值,通常是 pad_id

返回值::

两个张量的元组,与 sequences 的形状相同
  • padding_mask (torch.Tensor):一个 bool 张量,其中 True 表示该标记已截断。

  • sequences (torch.Tensor) 一个截断和填充序列的张量。

返回类型::

Tuple[torch.Tensor, torch.Tensor]

示例

>>> stop_token_ids = torch.tensor([2, 869])
>>> fill_value = 0
>>> sequences = torch.tensor(
>>>     [
>>>         [869, 30, 869],
>>>         [2, 30, 869],
>>>         [869, 30, 2],
>>>         [50, 30, 869],
>>>         [13, 30, 2],
>>>         [13, 30, 5],
>>>         [13, 2, 20],
>>>         [13, 2, 2],
>>>         [2, 2, 2],
>>>     ]
>>> )
>>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token(
>>>     sequences, stop_token_ids, fill_value
>>> )
>>> eos_mask
>>> torch.tensor([
>>>         [False, True, True],
>>>         [False, True, True],
>>>         [False, True, True],
>>>         [False, False, False],
>>>         [False, False, False],
>>>         [False, False, False],
>>>         [False, False, True],
>>>         [False, False, True],
>>>         [False, True, True],
>>>     ]
>>> )
>>> truncated_sequences
>>> torch.tensor([
>>>         [869, 0, 0],
>>>         [2, 0, 0],
>>>         [869, 0, 0],
>>>         [50, 30, 869],
>>>         [13, 30, 2],
>>>         [13, 30, 5],
>>>         [13, 2, 0],
>>>         [13, 2, 0],
>>>         [2, 0, 0],
>>>     ]
>>> )

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源