快捷方式

get_causal_mask_from_padding_mask

torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) Tensor[源代码]

将形状为 [bsz, seq_len] 的填充掩码转换为 [bsz, seq_len, seq_len] 的因果注意力掩码,适合由 scaled_dot_product_attention() 使用。如果提供了 target_seq_len,则将返回形状为 [bsz, seq_len, target_seq_len] 的掩码。当为静态 KV 缓存生成掩码时,这很有用,其中缓存已设置的最大长度大于当前序列。

参数:
  • padding_mask (torch.Tensor) – 布尔张量,其中 False 表示序列中相应的标记是填充标记,应在注意力中屏蔽,形状为 [bsz x seq_length]

  • target_seq_len (Optional[int]) – 用于创建注意力掩码的目标序列长度。默认为 None。

返回值:

形状为
  • [bsz, seq_length, seq_length] 或

  • [bsz, seq_length, target_seq_len](如果指定了 target_seq_len)。

返回类型:

torch.Tensor

引发:

AssertionError – 如果 target_seq_len > seq_len,即填充掩码的序列长度。

示例

>>> padding_mask = torch.tensor([[False, True, True, True]])
>>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5)
tensor([[[ True, False, False, False, False],
          [False,  True, False, False, False],
          [False,  True,  True, False, False],
          [False,  True,  True,  True, False]]])
])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源