MultiheadAttention¶
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]¶
使模型能够同时关注来自不同表示子空间的信息。
注意
请参阅本教程,深入讨论 PyTorch 为构建您自己的 Transformer 层提供的性能构建块。
论文中描述的方法:Attention Is All You Need。
多头注意力定义为
其中 。
nn.MultiheadAttention
在可能的情况下将使用scaled_dot_product_attention()
的优化实现。除了支持新的
scaled_dot_product_attention()
函数外,为了加速推理,如果满足以下条件,MHA 将使用支持 Nested Tensors 的 fastpath 推理:正在计算自注意力(即
query
、key
和value
是同一个 tensor)。输入是 batch 形式 (3D),且
batch_first==True
Autograd 已禁用(使用
torch.inference_mode
或torch.no_grad
),或者没有 tensor 参数requires_grad
训练已禁用(使用
.eval()
)add_bias_kv
为False
add_zero_attn
为False
kdim
和vdim
等于embed_dim
如果传入了 NestedTensor,则不会传入
key_padding_mask
或attn_mask
autocast 已禁用
如果使用了优化的推理 fastpath 实现,可以将 NestedTensor 传入
query
/key
/value
以比使用填充掩码更高效地表示填充。在这种情况下,将返回 NestedTensor,并且可以预期加速与输入中填充的比例成正比。- 参数
embed_dim – 模型的总维度。
num_heads – 并行注意力头的数量。注意
embed_dim
将被分割到num_heads
个头中(即每个头的维度将是embed_dim // num_heads
)。dropout –
attn_output_weights
上的 dropout 概率。默认值:0.0
(无 dropout)。bias – 如果指定,则向输入/输出投影层添加偏置。默认值:
True
。add_bias_kv – 如果指定,则向 key 和 value 序列的 dim=0 添加偏置。默认值:
False
。add_zero_attn – 如果指定,则向 key 和 value 序列的 dim=1 添加一批新的零。默认值:
False
。kdim – key 的特征总数。默认值:
None
(使用kdim=embed_dim
)。vdim – value 的特征总数。默认值:
None
(使用vdim=embed_dim
)。batch_first – 如果为
True
,则输入和输出 tensor 以 (batch, seq, feature) 形式提供。默认值:False
(seq, batch, feature)。
示例
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]¶
使用 query、key 和 value 嵌入计算注意力输出。
支持用于填充、掩码和注意力权重的可选参数。
- 参数
query (Tensor) – query 嵌入的形状为 (非 batch 输入),(当
batch_first=False
时)或 (当batch_first=True
时),其中 是目标序列长度, 是 batch 大小, 是 query 嵌入维度embed_dim
。query 与 key-value 对进行比较以产生输出。更多详细信息请参阅“Attention Is All You Need”。key (Tensor) – key 嵌入的形状为 (非 batch 输入),(当
batch_first=False
时)或 (当batch_first=True
时),其中 是源序列长度, 是 batch 大小, 是 key 嵌入维度kdim
。更多详细信息请参阅“Attention Is All You Need”。value (Tensor) – 值嵌入,未批处理输入的形状为 ,当
batch_first=False
时形状为 ,当batch_first=True
时形状为 。其中 是源序列长度, 是批量大小, 是值嵌入维度vdim
。更多详细信息请参阅“Attention Is All You Need”。key_padding_mask (Optional[Tensor]) – 如果指定,则为一个形状为 的掩码,指示
key
中哪些元素在注意力计算中应被忽略(即视为“填充”)。对于未批处理的 query,形状应为 。支持二进制掩码和浮点数掩码。对于二进制掩码,值为True
表示相应的key
值在注意力计算中将被忽略。对于浮点数掩码,其值将直接加到相应的key
值上。need_weights (bool) – 如果指定,除了
attn_outputs
外,还返回attn_output_weights
。设置need_weights=False
可以使用优化的scaled_dot_product_attention
,并为 MHA 实现最佳性能。默认值:True
。attn_mask (Optional[Tensor]) – 如果指定,则为一个 2D 或 3D 掩码,用于阻止对某些位置进行注意力计算。形状必须为 或 ,其中 是批量大小, 是目标序列长度, 是源序列长度。2D 掩码将在整个批次上广播,而 3D 掩码允许为批次中的每个条目使用不同的掩码。支持二进制掩码和浮点数掩码。对于二进制掩码,值为
True
表示不允许对相应的此位置进行注意力计算。对于浮点数掩码,掩码值将加到注意力权重上。如果同时提供了 attn_mask 和 key_padding_mask,它们的类型应匹配。average_attn_weights (bool) – 如果为 True,表示返回的
attn_weights
应在所有注意力头之间取平均值。否则,attn_weights
将单独提供每个注意力头的结果。请注意,此标志仅在need_weights=True
时有效。默认值:True
(即在注意力头之间取平均值)is_causal (bool) – 如果指定,则将因果掩码作为注意力掩码应用。默认值:
False
。警告:is_causal
提供了一个提示,表明attn_mask
是因果掩码。提供不正确的提示可能导致执行错误,包括前向和后向兼容性问题。
- 返回类型
- 输出
attn_output - 注意力输出,输入未批处理时形状为 ,当
batch_first=False
时形状为 ,当batch_first=True
时形状为 。其中 是目标序列长度, 是批量大小, 是嵌入维度embed_dim
。attn_output_weights - 仅在
need_weights=True
时返回。如果average_attn_weights=True
,返回在注意力头之间平均后的注意力权重,输入未批处理时形状为 或形状为 。其中 是批量大小, 是目标序列长度, 是源序列长度。如果average_attn_weights=False
,返回每个注意力头的注意力权重,输入未批处理时形状为 或形状为 。
注意
对于未批处理输入,batch_first 参数将被忽略。
- merge_masks(attn_mask, key_padding_mask, query)[source][source]¶
确定掩码类型并在必要时组合掩码。
如果仅提供一个掩码,将返回该掩码及相应的掩码类型。如果同时提供了两个掩码,它们都将被扩展到形状
(batch_size, num_heads, seq_len, seq_len)
,使用逻辑or
进行组合,并返回掩码类型 2。
:param attn_mask: 注意力掩码,形状为(seq_len, seq_len)
,掩码类型 0
:param key_padding_mask: 填充掩码,形状为(batch_size, seq_len)
,掩码类型 1
:param query: query 嵌入,形状为(batch_size, seq_len, embed_dim)
- 返回
合并后的掩码 mask_type: 合并后的掩码类型 (0, 1 或 2)
- 返回类型
merged_mask