快捷方式

get_causal_mask_from_padding_mask

torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) Tensor[源码]

将形状为 [bsz, seq_len] 的填充 mask 转换为适用于 scaled_dot_product_attention() 的形状为 [bsz, seq_len, seq_len] 的因果注意力 mask。如果提供了 target_seq_len,这将返回一个形状为 [bsz, seq_len, target_seq_len] 的 mask。这在为静态 KV 缓存生成 mask 时非常有用,静态 KV 缓存的最大长度可能比当前序列长度更长。

参数
  • padding_mask (torch.Tensor) – 布尔张量,其中 False 表示序列中对应的 token 是填充 token,应在注意力中被 masked out,形状为 [bsz x seq_length]

  • target_seq_len (Optional[int]) – 用于创建注意力 mask 的目标序列长度。默认为 None。

返回

形状为以下之一的布尔因果 mask
  • [bsz, seq_length, seq_length] 或

  • 如果指定了 target_seq_len,则为 [bsz, seq_length, target_seq_len]。

返回类型

torch.Tensor

抛出

AssertionError – 如果 target_seq_len < seq_len,即小于填充 mask 的序列长度。

示例

>>> padding_mask = torch.tensor([[False, True, True, True]])
>>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5)
tensor([[[ True, False, False, False, False],
          [False,  True, False, False, False],
          [False,  True,  True, False, False],
          [False,  True,  True,  True, False]]])
])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源