get_unmasked_sequence_lengths¶
- torchtune.training.get_unmasked_sequence_lengths(mask: Tensor) Tensor [源代码]¶
返回每个批次元素的序列长度,排除被掩码的 token。
- 参数:
mask (torch.Tensor) – 布尔掩码,形状为 [b x s],其中 True 表示要掩码的值。这通常是填充 token 的掩码,其中 True 表示填充 token。
- 返回:
序列索引 logits,形状为 [b]
- 返回类型:
Tensor
- 形状表示法
b = 批次大小
s = 序列长度
示例
>>> input_ids = torch.tensor([ ... [2, 4, 0, 0], ... [2, 4, 6, 0], ... [2, 4, 6, 9] ... ]) >>> mask = input_ids == 0 >>> mask tensor([[False, False, True, True], [False, False, False, True], [False, False, False, False]]) >>> get_unmasked_sequence_lengths(mask) tensor([1, 2, 3])