快捷方式

多头注意力机制

class torch.ao.nn.quantizable.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[源代码]
dequantize()[源代码]

将量化的 MHA 转换回浮点数的实用程序。

这样做的动机是,将量化版本中使用的权重从其格式转换回浮点数并非易事。

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[源代码]
注意:

请参阅 forward() 获取更多信息。

参数
  • query (张量) – 将查询和一组键值对映射到输出。有关更多详细信息,请参阅“Attention Is All You Need”。

  • key (张量) – 将查询和一组键值对映射到输出。有关更多详细信息,请参阅“Attention Is All You Need”。

  • value (张量) – 将查询和一组键值对映射到输出。有关更多详细信息,请参阅“Attention Is All You Need”。

  • key_padding_mask (可选[张量]) – 如果提供,则注意力机制将忽略键中的指定填充元素。当给出二进制掩码且值为 True 时,注意力层上的对应值将被忽略。

  • need_weights (布尔值) – 输出 attn_output_weights。

  • attn_mask (可选[张量]) – 2D 或 3D 掩码,用于阻止对某些位置的注意力。2D 掩码将广播到所有批次,而 3D 掩码允许为每个批次的条目指定不同的掩码。

返回值类型

元组[张量, 可选[张量]]

形状
  • 输入

  • 查询:(L,N,E)(L, N, E),其中 L 是目标序列长度,N 是批次大小,E 是嵌入维度。 (N,L,E)(N, L, E) 如果 batch_firstTrue

  • 键:(S,N,E)(S, N, E),其中 S 是源序列长度,N 是批次大小,E 是嵌入维度。 (N,S,E)(N, S, E) 如果 batch_firstTrue

  • 值:(S,N,E)(S, N, E),其中 S 是源序列长度,N 是批次大小,E 是嵌入维度。 (N,S,E)(N, S, E) 如果 batch_firstTrue

  • 键填充掩码:(N,S)(N, S),其中 N 是批次大小,S 是源序列长度。如果提供 BoolTensor,则值 为 True 的位置将被忽略,而值为 False 的位置将保持不变。

  • 注意力掩码:2D 掩码 (L,S)(L, S),其中 L 是目标序列长度,S 是源序列长度。3D 掩码 (Nnumheads,L,S)(N*num_heads, L, S),其中 N 是批次大小,L 是目标序列长度,S 是源序列长度。注意力掩码确保位置 i 可以关注未掩码的位置。如果提供 BoolTensor,值为 True 的位置将不允许关注,而值为 False 的位置将保持不变。如果提供 FloatTensor,它将被添加到注意力权重中。

  • 因果关系:如果指定,则应用因果关系掩码作为注意力掩码。与提供 attn_mask 相互排斥。默认值:False

  • 平均注意力权重:如果为真,则表示返回的 attn_weights 应在所有头部上取平均值。否则,attn_weights 将按头部分别提供。请注意,此标志仅在 need_weights=True. 时有效。默认值:True(即在所有头部上平均权重)。

  • 输出

  • attn_output:(L,N,E)(L, N, E),其中 L 是目标序列长度,N 是批次大小,E 是嵌入维度。 (N,L,E)(N, L, E) 如果 batch_firstTrue

  • attn_output_weights:如果 average_attn_weights=True,则返回在所有头部上取平均值的注意力权重,形状为 (N,L,S)(N, L, S),其中 N 是批次大小,L 是目标序列长度,S 是源序列长度。如果 average_attn_weights=False,则返回每个头部的注意力权重,形状为 (N,numheads,L,S)(N, num_heads, L, S).

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源