MultiheadAttention¶
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]¶
允许模型共同关注来自不同表示子空间的信息。
注
请参阅本教程,深入讨论 PyTorch 为构建您自己的 Transformer 层提供的性能构建块。
论文中描述的方法:Attention Is All You Need。
多头注意力定义为
其中 。
nn.MultiheadAttention将尽可能使用scaled_dot_product_attention()的优化实现。除了支持新的
scaled_dot_product_attention()函数以加速推理外,MHA 还将使用快速路径推理,并支持嵌套张量,如果满足以下条件:正在计算自注意力(即,
query、key和value是相同的张量)。输入是批处理的 (3D),且
batch_first==True自动微分已禁用(使用
torch.inference_mode或torch.no_grad),或者没有张量参数requires_grad训练已禁用(使用
.eval())add_bias_kv为Falseadd_zero_attn为Falsekdim和vdim等于embed_dim如果传递了 NestedTensor,则不传递
key_padding_mask和attn_mask自动类型转换已禁用
如果正在使用优化的推理快速路径实现,则可以为
query/key/value传递 NestedTensor,以比使用填充掩码更有效地表示填充。在这种情况下,将返回 NestedTensor,并且可以预期额外的加速与输入中填充部分的比例成正比。- 参数
embed_dim – 模型的总维度。
num_heads – 并行注意力头的数量。请注意,
embed_dim将在num_heads之间拆分(即,每个头的维度为embed_dim // num_heads)。dropout –
attn_output_weights上的 dropout 概率。默认值:0.0(无 dropout)。bias – 如果指定,则向输入/输出投影层添加偏置。默认值:
True。add_bias_kv – 如果指定,则向 dim=0 处的键和值序列添加偏置。默认值:
False。add_zero_attn – 如果指定,则向 dim=1 处的键和值序列添加新的零批次。默认值:
False。kdim – 键的特征总数。默认值:
None(使用kdim=embed_dim)。vdim – 值的特征总数。默认值:
None(使用vdim=embed_dim)。batch_first – 如果为
True,则输入和输出张量以 (batch, seq, feature) 形式提供。默认值:False(seq, batch, feature)。
示例
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]¶
使用查询、键和值嵌入计算注意力输出。
支持填充、掩码和注意力权重的可选参数。
- 参数
query (Tensor) – 查询嵌入,形状为 ,用于非批处理输入;当
batch_first=False时为 ;当batch_first=True时为 ,其中 是目标序列长度, 是批大小, 是查询嵌入维度embed_dim。查询与键值对进行比较以生成输出。有关更多详细信息,请参阅“Attention Is All You Need”。key (Tensor) – 键嵌入,形状为 ,用于非批处理输入;当
batch_first=False时为 ;当batch_first=True时为 ,其中 是源序列长度, 是批大小, 是键嵌入维度kdim。有关更多详细信息,请参阅“Attention Is All You Need”。value (Tensor) – 值嵌入,形状为 ,用于非批处理输入;当
batch_first=False时为 ;当batch_first=True时为 ,其中 是源序列长度, 是批大小, 是值嵌入维度vdim。有关更多详细信息,请参阅“Attention Is All You Need”。key_padding_mask (Optional[Tensor]) – 如果指定,则为形状为 的掩码,指示要忽略
key中哪些元素以进行注意力计算(即视为“填充”)。对于未批处理的 query,形状应为 。支持二进制和浮点掩码。对于二进制掩码,True值表示相应的key值将被忽略以进行注意力计算。对于浮点掩码,它将直接添加到相应的key值。need_weights (bool) – 如果指定,则除了
attn_outputs之外,还会返回attn_output_weights。设置need_weights=False以使用优化的scaled_dot_product_attention并为 MHA 实现最佳性能。默认值:True。attn_mask (Optional[Tensor]) – 如果指定,则为阻止关注某些位置的 2D 或 3D 掩码。 形状必须为 或 ,其中 是批大小, 是目标序列长度, 是源序列长度。 2D 掩码将在批次中广播,而 3D 掩码允许批次中每个条目使用不同的掩码。 支持二进制和浮点掩码。 对于二进制掩码,
True值表示不允许关注相应位置。 对于浮点掩码,掩码值将添加到注意力权重中。 如果同时提供 attn_mask 和 key_padding_mask,则它们的类型应匹配。average_attn_weights (bool) – 如果为 true,则表示返回的
attn_weights应在 head 之间平均。 否则,attn_weights将按 head 单独提供。 请注意,此标志仅在need_weights=True时有效。 默认值:True(即,跨 head 平均权重)is_causal (bool) – 如果指定,则应用因果掩码作为注意力掩码。 默认值:
False。 警告:is_causal提供了一个提示,表明attn_mask是因果掩码。 提供不正确的提示可能会导致执行不正确,包括向前和向后兼容性。
- 返回类型
- 输出
attn_output - 注意力输出,当输入未批处理时,形状为 ,当
batch_first=False时,形状为 ,或者当batch_first=True时,形状为 ,其中 是目标序列长度, 是批大小, 是嵌入维度embed_dim。attn_output_weights - 仅当
need_weights=True时返回。 如果average_attn_weights=True,则返回跨 head 平均的注意力权重,当输入未批处理时,形状为 ,或者形状为 ,其中 是批大小, 是目标序列长度, 是源序列长度。 如果average_attn_weights=False,则返回每个 head 的注意力权重,当输入未批处理时,形状为 ,或者形状为 。
注
对于未批处理的输入,将忽略 batch_first 参数。
- merge_masks(attn_mask, key_padding_mask, query)[source][source]¶
确定掩码类型并在必要时组合掩码。
如果仅提供一个掩码,则将返回该掩码和相应的掩码类型。 如果同时提供了两个掩码,则它们都将扩展到形状
(batch_size, num_heads, seq_len, seq_len),与逻辑or组合,并返回掩码类型 2 :param attn_mask: 形状为(seq_len, seq_len)的注意力掩码,掩码类型 0 :param key_padding_mask: 形状为(batch_size, seq_len)的填充掩码,掩码类型 1 :param query: 形状为(batch_size, seq_len, embed_dim)的查询嵌入- 返回
merged mask mask_type: 合并的掩码类型 (0, 1, 或 2)
- 返回类型
merged_mask