get_causal_mask_from_padding_mask¶
- torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) Tensor [源码]¶
将形状为
[bsz, seq_len]
的填充 mask 转换为适用于scaled_dot_product_attention()
的形状为[bsz, seq_len, seq_len]
的因果注意力 mask。如果提供了target_seq_len
,这将返回一个形状为[bsz, seq_len, target_seq_len]
的 mask。这在为静态 KV 缓存生成 mask 时非常有用,静态 KV 缓存的最大长度可能比当前序列长度更长。- 参数:
padding_mask (torch.Tensor) – 布尔张量,其中 False 表示序列中对应的 token 是填充 token,应在注意力中被 masked out,形状为 [bsz x seq_length]
target_seq_len (Optional[int]) – 用于创建注意力 mask 的目标序列长度。默认为 None。
- 返回:
- 形状为以下之一的布尔因果 mask
[bsz, seq_length, seq_length] 或
如果指定了
target_seq_len
,则为 [bsz, seq_length, target_seq_len]。
- 返回类型:
- 抛出:
AssertionError – 如果
target_seq_len < seq_len
,即小于填充 mask 的序列长度。
示例
>>> padding_mask = torch.tensor([[False, True, True, True]]) >>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5) tensor([[[ True, False, False, False, False], [False, True, False, False, False], [False, True, True, False, False], [False, True, True, True, False]]]) ])