get_causal_mask_from_padding_mask¶
- torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) Tensor [source]¶
将形状为
[bsz, seq_len]
的填充掩码转换为[bsz, seq_len, seq_len]
的因果注意力掩码,适用于scaled_dot_product_attention()
使用。如果提供了target_seq_len
,则这将返回形状为[bsz, seq_len, target_seq_len]
的掩码。当为静态 KV 缓存生成掩码时,这非常有用,因为缓存设置的最大长度比当前序列长。- 参数:
padding_mask (torch.Tensor) – 布尔张量,其中 False 表示序列中对应的 token 是填充 token,应在注意力中被掩盖,形状为 [bsz x seq_length]
target_seq_len (Optional[int]) – 要创建注意力掩码的目标序列长度。默认为 None。
- 返回:
- 布尔因果掩码,形状为
[bsz, seq_length, seq_length] 或
[bsz, seq_length, target_seq_len] 如果指定了
target_seq_len
。
- 返回类型:
- 引发:
AssertionError – 如果
target_seq_len > seq_len
,即填充掩码的序列长度。
示例
>>> padding_mask = torch.tensor([[False, True, True, True]]) >>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5) tensor([[[ True, False, False, False, False], [False, True, False, False, False], [False, True, True, False, False], [False, True, True, True, False]]]) ])