torchtext.nn¶
MultiheadAttentionContainer¶
- class torchtext.nn.MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[source]¶
- __init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False) None [source]¶
多头注意力容器
- 参数:
nhead – 多头注意力模型中的头数
in_proj_container – 多头输入投影线性层(即 nn.Linear)的容器。
attention_layer – 自定义注意力层。从 MHA 容器发送到注意力层的输入形状为 (…, L, N * H, E / H)(用于查询)和 (…, S, N * H, E / H)(用于键/值),而注意力层的输出形状应为 (…, L, N * H, E / H)。如果用户希望整体的 MultiheadAttentionContainer 支持广播,则 attention_layer 需要支持广播。
out_proj – 多头输出投影层(即 nn.Linear)。
batch_first – 如果为
True
,则输入和输出张量将提供为 (…, N, L, E)。默认值:False
- 示例:
>>> import torch >>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct >>> embed_dim, num_heads, bsz = 10, 5, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> MHA = MultiheadAttentionContainer(num_heads, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) >>> print(attn_output.shape) >>> torch.Size([21, 64, 10])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor] [source]¶
- 参数:
query (Tensor) – 注意力函数的查询。有关更多详细信息,请参阅“Attention Is All You Need”。
key (Tensor) – 注意力函数的键。有关更多详细信息,请参阅“Attention Is All You Need”。
value (Tensor) – 注意力函数的值。有关更多详细信息,请参阅“Attention Is All You Need”。
attn_mask (BoolTensor, 可选) – 用于阻止对某些位置进行注意力的 3D 掩码。
bias_k (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的键中。这些用于增量解码。用户应提供
bias_v
。bias_v (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的值中。这些用于增量解码。用户也应提供
bias_k
。
形状
输入
query: \((..., L, N, E)\)
key: \((..., S, N, E)\)
value: \((..., S, N, E)\)
attn_mask、bias_k 和 bias_v:与注意力层中相应参数的形状相同。
输出
attn_output: \((..., L, N, E)\)
attn_output_weights: \((N * H, L, S)\)
注意:查询/键/值输入可以有超过三个维度(用于广播目的),这是可选的。MultiheadAttentionContainer 模块将在最后三个维度上进行操作。
其中 L 是目标长度,S 是序列长度,H 是注意力头的数量,N 是批次大小,E 是嵌入维度。
InProjContainer¶
- class torchtext.nn.InProjContainer(query_proj, key_proj, value_proj)[source]¶
- __init__(query_proj, key_proj, value_proj) None [source]¶
用于在MultiheadAttention中投影查询/键/值的内部投影容器。此模块发生在将投影后的查询/键/值重塑为多个头之前。请参阅“Attention Is All You Need”论文图2中多头注意力(底部)的线性层。还可以查看torchtext.nn.MultiheadAttentionContainer中的使用示例。
- 参数:
query_proj – 查询的投影层。一个典型的投影层是torch.nn.Linear。
key_proj – 键的投影层。一个典型的投影层是torch.nn.Linear。
value_proj – 值的投影层。一个典型的投影层是torch.nn.Linear。
- forward(query: Tensor, key: Tensor, value: Tensor) Tuple[Tensor, Tensor, Tensor] [source]¶
使用内部投影层投影输入序列。查询/键/值分别简单地传递到query/key/value_proj的前向函数。
- 参数:
query (Tensor) – 要投影的查询。
key (Tensor) – 要投影的键。
value (Tensor) – 要投影的值。
- 示例:
>>> import torch >>> from torchtext.nn import InProjContainer >>> embed_dim, bsz = 10, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> q = torch.rand((5, bsz, embed_dim)) >>> k = v = torch.rand((6, bsz, embed_dim)) >>> q, k, v = in_proj_container(q, k, v)
ScaledDotProduct¶
- class torchtext.nn.ScaledDotProduct(dropout=0.0, batch_first=False)[source]¶
- __init__(dropout=0.0, batch_first=False) None [source]¶
处理投影后的查询和键值对以应用缩放点积注意力。
- 参数:
dropout (float) – 丢弃注意力权重的概率。
batch_first – 如果为
True
,则输入和输出张量将作为(batch, seq, feature)提供。默认值:False
- 示例:
>>> import torch, torchtext >>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1) >>> q = torch.randn(21, 256, 3) >>> k = v = torch.randn(21, 256, 3) >>> attn_output, attn_weights = SDP(q, k, v) >>> print(attn_output.shape, attn_weights.shape) torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor] [source]¶
使用与投影后的键值对的缩放点积来更新投影后的查询。
- 参数:
query (Tensor) – 投影后的查询
key (Tensor) – 投影后的键
value (Tensor) – 投影后的值
attn_mask (BoolTensor, 可选) – 用于阻止对某些位置进行注意力的 3D 掩码。
attn_mask – 3D掩码,用于阻止对某些位置的注意力。
bias_k (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的键中。这些用于增量解码。用户应提供
bias_v
。bias_v (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的值中。这些用于增量解码。用户也应提供
bias_k
。
- 形状
query: \((..., L, N * H, E / H)\)
key: \((..., S, N * H, E / H)\)
value: \((..., S, N * H, E / H)\)
- attn_mask: \((N * H, L, S)\), 值为
True
的位置不允许参与注意力 而值为
False
的位置将保持不变。
- attn_mask: \((N * H, L, S)\), 值为
bias_k 和 bias_v:bias: \((1, N * H, E / H)\)
输出: \((..., L, N * H, E / H)\), \((N * H, L, S)\)
- 注意:查询/键/值输入可以有超过三个维度(用于广播目的),这是可选的。
ScaledDotProduct模块将对最后三个维度进行操作。
其中L是目标长度,S是源长度,H是注意力头的数量,N是批次大小,E是嵌入维度。