get_causal_mask_from_padding_mask¶
- torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) Tensor [源代码]¶
将形状为
[bsz, seq_len]
的填充掩码转换为[bsz, seq_len, seq_len]
的因果注意力掩码,适合由scaled_dot_product_attention()
使用。如果提供了target_seq_len
,则将返回形状为[bsz, seq_len, target_seq_len]
的掩码。当为静态 KV 缓存生成掩码时,这很有用,其中缓存已设置的最大长度大于当前序列。- 参数:
padding_mask (torch.Tensor) – 布尔张量,其中 False 表示序列中相应的标记是填充标记,应在注意力中屏蔽,形状为 [bsz x seq_length]
target_seq_len (Optional[int]) – 用于创建注意力掩码的目标序列长度。默认为 None。
- 返回值:
- 形状为
[bsz, seq_length, seq_length] 或
[bsz, seq_length, target_seq_len](如果指定了
target_seq_len
)。
- 返回类型:
- 引发:
AssertionError – 如果
target_seq_len > seq_len
,即填充掩码的序列长度。
示例
>>> padding_mask = torch.tensor([[False, True, True, True]]) >>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5) tensor([[[ True, False, False, False, False], [False, True, False, False, False], [False, True, True, False, False], [False, True, True, True, False]]]) ])