快捷方式

MultiHeadAttention

class torchtune.modules.MultiHeadAttention(*, embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Optional[Module] = None, q_norm: Optional[Module] = None, k_norm: Optional[Module] = None, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0)[source]

多头注意力层,支持分组查询注意力 (GQA),该技术在 https://arxiv.org/abs/2305.13245v1 中引入。

GQA 是一种多头注意力 (MHA) 的变体,它通过将每 n 个查询头分组给一个键/值头来减少键/值头的数量。多查询注意力 (MQA) 是一种极端情况,所有查询头共享单个键/值头。

以下是 num_heads = 4 时 MHA、GQA 和 MQA 的示例

(文档来源:litgpt.Config)。

┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
│    │    │    │         │        │                 │
┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
│ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
│    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
│ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
└───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
        MHA                    GQA                   MQA
n_kv_heads =4          n_kv_heads=2           n_kv_heads=1
参数:
  • embed_dim (int) – 模型的嵌入维度

  • num_heads (int) – 查询头的数量。对于 MHA,这也是键和值头的数量

  • num_kv_heads (int) – 键和值头的数量。用户应确保 num_heads % num_kv_heads == 0。对于标准 MHA,设置 num_kv_heads == num_heads;对于 GQA,设置 num_kv_heads < num_heads;对于 MQA,设置 num_kv_heads == 1

  • head_dim (int) – 每个头的维度,由 embed_dim // num_heads 计算得出。

  • q_proj (nn.Module) – 查询的投影层。

  • k_proj (nn.Module) – 键的投影层。

  • v_proj (nn.Module) – 值的投影层。

  • output_proj (nn.Module) – 输出的投影层。

  • pos_embeddings (Optional[nn.Module]) – 位置嵌入层,例如 RotaryPositionalEmbeddings。

  • q_norm (Optional[nn.Module]) – 查询的归一化层,例如 RMSNorm。对于解码,这在从 kv_cache 更新之前应用。这意味着它只支持 token 范围的归一化,不支持 batch 或序列范围的归一化。

  • k_norm (Optional[nn.Module]) – 键的归一化层,如果设置了 q_norm 则必须设置此项。

  • kv_cache (Optional[KVCache]) – 用于缓存键和值的 KVCache 对象

  • max_seq_len (int) – 模型支持的最大序列长度。计算 RoPE Cache 时需要此参数。默认值:4096。

  • is_causal (bool) – 在未提供 mask 时将默认 mask 设置为因果 mask

  • attn_dropout (float) – 传递给 scaled_dot_product_attention 函数的 dropout 值。默认值为 0.0。

抛出:

ValueError – 如果 num_heads % num_kv_heads != 0,或者 embed_dim % num_heads != 0,或者 attn_dropout < 0attn_dropout > 1,或者在定义 q_norm 时未定义 k_norm 或反之亦然

forward(x: Tensor, y: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Tensor[source]
参数:
  • x (torch.Tensor) – 查询的输入张量,形状为 [b x s_x x d]

  • y (Optional[torch.Tensor]) – 第二个输入张量,形状为 [b x s_y x d],是 k 和 v 的输入。对于自注意力,x=y。仅在使用 kv_cache 时可选。

  • mask (Optional[_MaskType]) –

    用于在查询-键乘法之后、softmax 之前对分数进行掩码。可以是

    一个布尔张量,形状为 [b x s x s][b x s x self.encoder_max_cache_seq_len][b x s x self.decoder_max_cache_seq_len](如果使用 encoder/decoder 层进行 KV 缓存)。在行 i 和列 j 中值为 True 表示 token i 注意 token j。值为 False 表示 token i 不注意 token j。如果未指定 mask,默认使用因果 mask。

    一个 BlockMask,用于通过 create_block_mask 创建的 packed sequence 中的文档掩码。在使用 block mask 计算注意力时,我们使用 flex_attention()。默认值为 None。

  • input_pos (Optional[torch.Tensor]) – 可选张量,包含每个 token 的位置 ID。在训练期间,这用于指示打包时每个 token 相对于其样本的位置,形状为 [b x s]。在推理期间,这指示当前 token 的位置。如果为 none,则假定 token 的索引为其位置 ID。默认值为 None。

抛出:

ValueError – 如果没有 y 输入且未启用 kv_cache

返回:

应用注意力后的输出张量

返回类型:

torch.Tensor

张量形状中使用的表示法
  • b: batch size

  • s_x: sequence length for x

  • s_y: sequence length for y

  • n_h: num heads

  • n_kv: num kv heads

  • d: embed dim

  • h_d: head dim

reset_cache()[source]

重置键值缓存。

setup_cache(batch_size: int, dtype: dtype, max_seq_len: int) None[source]

设置用于注意力计算的键值缓存。如果在 kv_cache 已经设置后再次调用,则会跳过。

参数:
  • batch_size (int) – 缓存的 batch size。

  • dtype (torch.dpython:type) – 缓存的 dtype。

  • max_seq_len (int) – 模型将运行的最大序列长度。

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源