多头注意力¶
- class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source][source]¶
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source][source]¶
- 注意:
更多信息请参考
forward()
- 参数
query (Tensor) – 将查询和一组键值对映射到输出。更多详细信息请参阅“Attention Is All You Need”论文。
key (Tensor) – 将查询和一组键值对映射到输出。更多详细信息请参阅“Attention Is All You Need”论文。
value (Tensor) – 将查询和一组键值对映射到输出。更多详细信息请参阅“Attention Is All You Need”论文。
key_padding_mask (Optional[Tensor]) – 如果提供,键中指定的填充元素将被注意力机制忽略。当给定一个二值掩码且值为 True 时,注意力层上的相应值将被忽略。
need_weights (bool) – 输出 attn_output_weights。
attn_mask (Optional[Tensor]) – 用于阻止注意力关注特定位置的二维或三维掩码。二维掩码将广播应用于所有批次,而三维掩码允许为每个批次的条目指定不同的掩码。
- 返回类型
- 形状
输入
query: 其中 L 是目标序列长度,N 是批次大小,E 是嵌入维度。 如果
batch_first
为True
。key: ,其中 S 是源序列长度,N 是批次大小,E 是嵌入维度。 如果
batch_first
为True
。value: 其中 S 是源序列长度,N 是批次大小,E 是嵌入维度。 如果
batch_first
为True
。key_padding_mask: 其中 N 是批次大小,S 是源序列长度。如果提供了 BoolTensor,值为
True
的位置将被忽略,而值为False
的位置保持不变。attn_mask: 二维掩码 其中 L 是目标序列长度,S 是源序列长度。三维掩码 其中 N 是批次大小,L 是目标序列长度,S 是源序列长度。attn_mask 确保位置 i 能够关注未被掩码的位置。如果提供了 BoolTensor,值为
True
的位置不允许关注,而值为False
的位置保持不变。如果提供了 FloatTensor,它将被添加到注意力权重中。is_causal: 如果指定,将应用因果掩码作为注意力掩码。与提供 attn_mask 互斥。默认值:
False
。average_attn_weights: 如果为 True,表示返回的
attn_weights
应在所有注意力头之间平均。否则,attn_weights
将按注意力头单独提供。请注意,此标志仅在need_weights=True
时有效。默认值:True(即在注意力头之间平均权重)输出
attn_output: 其中 L 是目标序列长度,N 是批次大小,E 是嵌入维度。 如果
batch_first
为True
。attn_output_weights: 如果
average_attn_weights=True
,返回在注意力头之间平均的注意力权重,形状为 ,其中 N 是批次大小,L 是目标序列长度,S 是源序列长度。如果average_attn_weights=False
,返回每个注意力头的注意力权重,形状为 。