get_unmasked_sequence_lengths¶
- torchtune.training.get_unmasked_sequence_lengths(mask: Tensor) Tensor [source]¶
返回每个批次元素的序列长度,不包括掩码标记。
- 参数:
mask (torch.Tensor) – 形状为 [b x s] 的布尔掩码,其中 True 表示要屏蔽的值。这通常是填充标记的掩码,其中 True 表示填充标记。
- 返回值:
形状为 [b] 的序列索引 logits
- 返回类型:
张量
- 形状表示法
b = 批次大小
s = 序列长度
示例
>>> input_ids = torch.tensor([ ... [2, 4, 0, 0], ... [2, 4, 6, 0], ... [2, 4, 6, 9] ... ]) >>> mask = input_ids == 0 >>> mask tensor([[False, False, True, True], [False, False, False, True], [False, False, False, False]]) >>> get_unmasked_sequence_lengths(mask) tensor([1, 2, 3])