快捷方式

get_unmasked_sequence_lengths

torchtune.training.get_unmasked_sequence_lengths(mask: Tensor) Tensor[源代码]

返回每个批次元素的序列长度,排除被掩码的 token。

参数:

mask (torch.Tensor) – 布尔掩码,形状为 [b x s],其中 True 表示要掩码的值。这通常是填充 token 的掩码,其中 True 表示填充 token。

返回:

序列索引 logits,形状为 [b]

返回类型:

Tensor

形状表示法
  • b = 批次大小

  • s = 序列长度

示例

>>> input_ids = torch.tensor([
...        [2, 4, 0, 0],
...        [2, 4, 6, 0],
...        [2, 4, 6, 9]
...    ])
>>> mask = input_ids == 0
>>> mask
tensor([[False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
>>> get_unmasked_sequence_lengths(mask)
tensor([1, 2, 3])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源