MultiHeadAttention¶
- class torchtune.modules.MultiHeadAttention(*, embed_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, q_proj: Module, k_proj: Module, v_proj: Module, output_proj: Module, pos_embeddings: Optional[Module] = None, q_norm: Optional[Module] = None, k_norm: Optional[Module] = None, kv_cache: Optional[KVCache] = None, max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0)[source]¶
支持 https://arxiv.org/abs/2305.13245v1 中引入的分组查询注意力 (GQA) 的多头注意力层。
GQA 是多头注意力 (MHA) 的一种变体,它通过为每个键和值头分组 n 个查询头来使用比查询头更少的键/值头。多查询注意力是一个极端版本,其中我们有一个由所有查询头共享的键和值头。
以下是 num_heads = 4 的 MHA、GQA 和 MQA 示例
(文档来源: litgpt.Config).
┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ │ │ │ │ │ │ │ ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ MHA GQA MQA n_kv_heads =4 n_kv_heads=2 n_kv_heads=1
- 参数:
embed_dim (int) – 模型的嵌入维度
num_heads (int) – 查询头的数量。对于 MHA,这也是键和值的头部数量
num_kv_heads (int) – 键和值头的数量。用户应确保
num_heads % num_kv_heads == 0
。对于标准 MHA,设置num_kv_heads == num_heads
,对于 GQAnum_kv_heads < num_heads
,对于 MQA 设置num_kv_heads == 1
。head_dim (int) – 每个头的维度,由
embed_dim // num_heads
计算。q_proj (nn.Module) – 查询的投影层。
k_proj (nn.Module) – 键的投影层。
v_proj (nn.Module) – 值的投影层。
output_proj (nn.Module) – 输出的投影层。
pos_embeddings (可选[nn.Module]) – 位置嵌入层,例如 RotaryPositionalEmbeddings。
q_norm (可选[nn.Module]) – 查询的归一化层,例如 RMSNorm。在解码过程中,此操作在从 kv_cache 更新之前应用。这意味着它仅支持 token 级别的归一化,而不支持批次或序列级别的归一化。
k_norm (可选[nn.Module]) – 密钥的归一化层,如果设置了 q_norm,则必须设置此参数。
kv_cache (可选[KVCache]) – 用于缓存键和值的 KVCache 对象
max_seq_len (int) – 模型支持的最大序列长度。这需要计算 RoPE Cache。默认值:4096。
is_causal (bool) – 当未提供掩码时,将默认掩码设置为因果掩码
attn_dropout (float) – 传递给 scaled_dot_product_attention 函数的 dropout 值。如果 self.training 为 False,则忽略此参数。默认值为 0.0。
- 引发:
ValueError – 如果
num_heads % num_kv_heads != 0
ValueError – 如果
embed_dim % num_heads != 0
ValueError – 如果
attn_dropout < 0
或attn_dropout > 1
ValueError – 如果定义了 q_norm 但未定义 k_norm,反之亦然
- forward(x: Tensor, y: Optional[Tensor] = None, *, mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Tensor [source]¶
- 参数:
x (torch.Tensor) – 查询的输入张量,形状为 [b x s_x x d]
y (可选[torch.Tensor]) – 第二个输入张量,形状为 [b x s_y x d],是 k 和 v 的输入。对于自注意力,x=y。仅在启用 kv_cache 时可选。
mask (可选[_MaskType]) –
用于在查询-键乘法之后以及 softmax 之前掩盖分数。可以是:
形状为
[b x s x s]
、[b x s x self.encoder_max_cache_seq_len]
或[b x s x self.encoder_max_cache_seq_len]
的布尔张量(如果使用具有编码器/解码器层的 KV 缓存)。第i
行和第j
列中的 True 值表示 tokeni
注意到 tokenj
。False 值表示 tokeni
未注意到 tokenj
。如果未指定掩码,则默认使用因果掩码。用于通过 create_block_mask 创建的打包序列中的文档掩码的
BlockMask
。在使用块掩码计算注意力时,我们使用flex_attention()
。默认为 None。input_pos (可选[torch.Tensor]) – 包含每个 token 的位置 ID 的可选张量。在训练期间,用于指示打包时每个 token 相对于其样本的位置,形状为 [b x s]。在推理过程中,指示当前 token 的位置。如果没有,则假设 token 的索引是其位置 ID。默认为 None。
- 引发:
ValueError – 如果没有
y
输入且未启用kv_cache
。- 返回:
应用了注意力后的输出张量
- 返回类型:
- 张量形状使用的符号
b:批次大小
s_x:x 的序列长度
s_y:y 的序列长度
n_h:头数
n_kv:kv 头数
d:嵌入维度
h_d:头维度