truncate_sequence_at_first_stop_token¶
- torchtune.rlhf.truncate_sequence_at_first_stop_token(sequences: Tensor, stop_tokens: Tensor, fill_value: int = 0) Tuple[Tensor, Tensor] [源代码]¶
在第一个停止标记后截断一个或多个序列,并用
fill_value
进行填充。- 参数:
sequences (torch.Tensor) – 形状为 [batch_size, sequence_length] 或 [sequence_length] 的张量。
stop_tokens (torch.Tensor) – 包含停止标记的张量。
fill_value (int) – 在第一个停止标记后用于填充序列的值,通常是
pad_id
。
- 返回值:
- 形状与
sequences
相同的两个张量组成的元组 padding_mask (torch.Tensor): 一个布尔张量,其中
True
表示标记已被截断。sequences (torch.Tensor) 一个被截断和填充的序列张量。
- 形状与
- 返回类型:
Tuple[torch.Tensor, torch.Tensor]
示例
>>> stop_token_ids = torch.tensor([2, 869]) >>> fill_value = 0 >>> sequences = torch.tensor( >>> [ >>> [869, 30, 869], >>> [2, 30, 869], >>> [869, 30, 2], >>> [50, 30, 869], >>> [13, 30, 2], >>> [13, 30, 5], >>> [13, 2, 20], >>> [13, 2, 2], >>> [2, 2, 2], >>> ] >>> ) >>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token( >>> sequences, stop_token_ids, fill_value >>> ) >>> eos_mask >>> torch.tensor([ >>> [False, True, True], >>> [False, True, True], >>> [False, True, True], >>> [False, False, False], >>> [False, False, False], >>> [False, False, False], >>> [False, False, True], >>> [False, False, True], >>> [False, True, True], >>> ] >>> ) >>> truncated_sequences >>> torch.tensor([ >>> [869, 0, 0], >>> [2, 0, 0], >>> [869, 0, 0], >>> [50, 30, 869], >>> [13, 30, 2], >>> [13, 30, 5], >>> [13, 2, 0], >>> [13, 2, 0], >>> [2, 0, 0], >>> ] >>> )