快捷方式

get_unmasked_sequence_lengths

torchtune.training.get_unmasked_sequence_lengths(mask: Tensor) Tensor[source]

返回每个批次元素的序列长度,不包括掩码标记。

参数:

mask (torch.Tensor) – 形状为 [b x s] 的布尔掩码,其中 True 表示要屏蔽的值。这通常是填充标记的掩码,其中 True 表示填充标记。

返回值:

形状为 [b] 的序列索引 logits

返回类型:

张量

形状表示法
  • b = 批次大小

  • s = 序列长度

示例

>>> input_ids = torch.tensor([
...        [2, 4, 0, 0],
...        [2, 4, 6, 0],
...        [2, 4, 6, 9]
...    ])
>>> mask = input_ids == 0
>>> mask
tensor([[False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
>>> get_unmasked_sequence_lengths(mask)
tensor([1, 2, 3])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源