truncate_sequence_at_first_stop_token¶
- torchtune.rlhf.truncate_sequence_at_first_stop_token(sequences: Tensor, stop_tokens: Tensor, fill_value: int = 0) Tuple[Tensor, Tensor] [source]¶
在第一个停止标记之后截断序列(s),并用
fill_value
填充。- 参数::
sequences (torch.Tensor) – 形状为 [batch_size, sequence_length] 或 [sequence_length] 的张量。
stop_tokens (torch.Tensor) – 包含停止标记的张量。
fill_value (int) – 在第一个停止标记之后用它来填充序列的值,通常是
pad_id
。
- 返回值::
- 两个张量的元组,与
sequences
的形状相同 padding_mask (torch.Tensor):一个 bool 张量,其中 True 表示该标记已截断。
sequences (torch.Tensor) 一个截断和填充序列的张量。
- 两个张量的元组,与
- 返回类型::
Tuple[torch.Tensor, torch.Tensor]
示例
>>> stop_token_ids = torch.tensor([2, 869]) >>> fill_value = 0 >>> sequences = torch.tensor( >>> [ >>> [869, 30, 869], >>> [2, 30, 869], >>> [869, 30, 2], >>> [50, 30, 869], >>> [13, 30, 2], >>> [13, 30, 5], >>> [13, 2, 20], >>> [13, 2, 2], >>> [2, 2, 2], >>> ] >>> ) >>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token( >>> sequences, stop_token_ids, fill_value >>> ) >>> eos_mask >>> torch.tensor([ >>> [False, True, True], >>> [False, True, True], >>> [False, True, True], >>> [False, False, False], >>> [False, False, False], >>> [False, False, False], >>> [False, False, True], >>> [False, False, True], >>> [False, True, True], >>> ] >>> ) >>> truncated_sequences >>> torch.tensor([ >>> [869, 0, 0], >>> [2, 0, 0], >>> [869, 0, 0], >>> [50, 30, 869], >>> [13, 30, 2], >>> [13, 30, 5], >>> [13, 2, 0], >>> [13, 2, 0], >>> [2, 0, 0], >>> ] >>> )