MultiheadAttention¶
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源代码]¶
允许模型共同关注来自不同表示子空间的信息。
论文中描述的方法:注意力就是你所需要的一切。
多头注意力定义为
其中 .
nn.MultiHeadAttention
将在可能的情况下使用scaled_dot_product_attention()
的优化实现。除了支持新的
scaled_dot_product_attention()
函数外,为了加快推理速度,MHA 将使用快速路径推理并支持嵌套张量,当且仅当正在计算自注意力(即,
query
、key
和value
是相同的张量)。输入被批处理(3D)且
batch_first==True
要么自动梯度已禁用(使用
torch.inference_mode
或torch.no_grad
),要么没有张量参数requires_grad
训练已禁用(使用
.eval()
)add_bias_kv
为False
add_zero_attn
为False
kdim
和vdim
等于embed_dim
如果传递了 NestedTensor,则不会传递
key_padding_mask
或attn_mask
自动转换已禁用
如果正在使用优化的推理快速路径实现,则可以将 NestedTensor 传递给
query
/key
/value
以更有效地表示填充,而不是使用填充掩码。在这种情况下,将返回 NestedTensor,并且可以预期额外加速与输入中填充部分的比例成正比。- 参数
embed_dim – 模型的总维度。
num_heads – 并行注意力头的数量。请注意,
embed_dim
将在num_heads
中分割(即,每个头的维度将为embed_dim // num_heads
)。dropout –
attn_output_weights
上的 dropout 概率。默认值:0.0
(无 dropout)。bias – 如果指定,则向输入/输出投影层添加偏差。默认值:
True
。add_bias_kv – 如果指定,则在 dim=0 处向键和值序列添加偏差。默认值:
False
。add_zero_attn – 如果指定,则在 dim=1 处向键和值序列添加一个新的零批次。默认值:
False
。kdim – 键的特征总数。默认值:
None
(使用kdim=embed_dim
)。vdim – 值的特征总数。默认值:
None
(使用vdim=embed_dim
)。batch_first – 如果为
True
,则输入和输出张量将作为 (batch, seq, feature) 提供。默认值:False
(seq, batch, feature)。
示例
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source]¶
使用查询、键和值嵌入计算注意力输出。
支持用于填充、掩码和注意力权重的可选参数。
- 参数
query (Tensor) – 查询嵌入,形状为 (对于非批处理输入), 当
batch_first=False
时,或 当batch_first=True
时,其中 是目标序列长度, 是批次大小, 是查询嵌入维度embed_dim
。查询与键值对进行比较以生成输出。更多细节请参阅“Attention Is All You Need”。key (Tensor) – 键嵌入,形状为 (对于非批处理输入), 当
batch_first=False
时,或 当batch_first=True
时,其中 是源序列长度, 是批次大小, 是键嵌入维度kdim
。更多细节请参阅“Attention Is All You Need”。value (张量) – 值嵌入,形状为 (对于非批处理输入), 当
batch_first=False
或 当batch_first=True
,其中 是源序列长度, 是批大小, 是值嵌入维度vdim
。更多细节请参考“Attention Is All You Need”。key_padding_mask (可选[张量]) – 如果指定,则为形状为 的掩码,指示为关注目的而忽略
key
中的哪些元素(即将其视为“填充”)。对于非批处理的 query,形状应为 。支持二进制和浮点掩码。对于二进制掩码,True
值表示将忽略相应的key
值以进行关注。对于浮点掩码,它将直接添加到相应的key
值中。need_weights (布尔值) – 如果指定,则除了
attn_outputs
之外,还返回attn_output_weights
。设置need_weights=False
以使用优化的scaled_dot_product_attention
并获得 MHA 的最佳性能。默认值:True
。attn_mask (可选[张量]) – 如果指定,则为 2D 或 3D 掩码,防止关注某些位置。必须为形状
) (L, S) 或 ,其中 是批大小, 是目标序列长度, 是源序列长度。2D 掩码将跨批广播,而 3D 掩码允许批中每个条目使用不同的掩码。支持二进制和浮点掩码。对于二进制掩码,True
值表示不允许关注相应的位置。对于浮点掩码,掩码值将添加到注意力权重中。如果同时提供 attn_mask 和 key_padding_mask,则它们的类型应匹配。average_attn_weights (布尔值) – 如果为真,则表示返回的
attn_weights
应在头部之间取平均值。否则,attn_weights
将按每个头部单独提供。请注意,此标志仅在need_weights=True
时有效。默认值:True
(即在头部之间平均权重)is_causal (布尔值) – 如果指定,则将因果掩码应用为注意力掩码。默认值:
False
。警告:is_causal
提供一个提示,表明attn_mask
是因果掩码。提供不正确的提示可能导致执行错误,包括向前和向后兼容性。
- 返回类型
- 输出
attn_output - 注意力输出,形状为 (当输入未批处理时),(当
batch_first=False
时)或 (当batch_first=True
时),其中 是目标序列长度, 是批次大小, 是嵌入维度embed_dim
。attn_output_weights - 仅当
need_weights=True
时返回。如果average_attn_weights=True
,则返回跨头的平均注意力权重,形状为 (当输入未批处理时)或 ,其中 是批次大小, 是目标序列长度, 是源序列长度。如果average_attn_weights=False
,则返回每个头的注意力权重,形状为 (当输入未批处理时)或 .
注意
对于未批处理的输入,batch_first 参数将被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[source]¶
确定掩码类型并在必要时组合掩码。
如果仅提供一个掩码,则将返回该掩码和相应的掩码类型。如果同时提供两个掩码,则将两者扩展到形状
(batch_size, num_heads, seq_len, seq_len)
,并使用逻辑or
组合,并将掩码类型返回为 2 :param attn_mask: 形状为(seq_len, seq_len)
的注意力掩码,掩码类型为 0 :param key_padding_mask: 形状为(batch_size, seq_len)
的填充掩码,掩码类型为 1 :param query: 形状为(batch_size, seq_len, embed_dim)
的查询嵌入- 返回
合并后的掩码 mask_type: 合并后的掩码类型(0、1 或 2)
- 返回类型
merged_mask