快捷方式

get_unmasked_sequence_lengths

torchtune.training.get_unmasked_sequence_lengths(mask: Tensor) Tensor[源码]

返回每个批次元素的序列长度(0-索引),排除被遮盖的 token。

参数:

mask (torch.Tensor) – 形状为 [b x s] 的布尔掩码,其中 True 表示需要被遮盖的值。这通常是用于填充 token 的掩码,True 表示一个填充 token。

返回:

形状为 [b] 的序列索引 logits

返回类型:

Tensor

形状表示法
  • b = 批次大小

  • s = 序列长度

示例

>>> input_ids = torch.tensor([
...        [2, 4, 0, 0],
...        [2, 4, 6, 0],
...        [2, 4, 6, 9]
...    ])
>>> mask = input_ids == 0
>>> mask
tensor([[False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
>>> get_unmasked_sequence_lengths(mask)
tensor([1, 2, 3])

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源