快捷方式

torchtext.nn

MultiheadAttentionContainer

class torchtext.nn.MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[source]
__init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False) None[source]

多头注意力容器

参数:
  • nhead – 多头注意力模型中的头数

  • in_proj_container – 多头输入投影线性层(即 nn.Linear)的容器。

  • attention_layer – 自定义注意力层。从 MHA 容器发送到注意力层的输入形状为 (…, L, N * H, E / H)(用于查询)和 (…, S, N * H, E / H)(用于键/值),而注意力层的输出形状应为 (…, L, N * H, E / H)。如果用户希望整体的 MultiheadAttentionContainer 支持广播,则 attention_layer 需要支持广播。

  • out_proj – 多头输出投影层(即 nn.Linear)。

  • batch_first – 如果为 True,则输入和输出张量将提供为 (…, N, L, E)。默认值:False

示例:
>>> import torch
>>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz = 10, 5, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads,
                                      in_proj_container,
                                      ScaledDotProduct(),
                                      torch.nn.Linear(embed_dim, embed_dim))
>>> query = torch.rand((21, bsz, embed_dim))
>>> key = value = torch.rand((16, bsz, embed_dim))
>>> attn_output, attn_weights = MHA(query, key, value)
>>> print(attn_output.shape)
>>> torch.Size([21, 64, 10])
forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor][source]
参数:
  • query (Tensor) – 注意力函数的查询。有关更多详细信息,请参阅“Attention Is All You Need”。

  • key (Tensor) – 注意力函数的键。有关更多详细信息,请参阅“Attention Is All You Need”。

  • value (Tensor) – 注意力函数的值。有关更多详细信息,请参阅“Attention Is All You Need”。

  • attn_mask (BoolTensor, 可选) – 用于阻止对某些位置进行注意力的 3D 掩码。

  • bias_k (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的键中。这些用于增量解码。用户应提供 bias_v

  • bias_v (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的值中。这些用于增量解码。用户也应提供 bias_k

形状

  • 输入

    • query: \((..., L, N, E)\)

    • key: \((..., S, N, E)\)

    • value: \((..., S, N, E)\)

    • attn_mask、bias_k 和 bias_v:与注意力层中相应参数的形状相同。

  • 输出

    • attn_output: \((..., L, N, E)\)

    • attn_output_weights: \((N * H, L, S)\)

注意:查询/键/值输入可以有超过三个维度(用于广播目的),这是可选的。MultiheadAttentionContainer 模块将在最后三个维度上进行操作。

其中 L 是目标长度,S 是序列长度,H 是注意力头的数量,N 是批次大小,E 是嵌入维度。

InProjContainer

class torchtext.nn.InProjContainer(query_proj, key_proj, value_proj)[source]
__init__(query_proj, key_proj, value_proj) None[source]

用于在MultiheadAttention中投影查询/键/值的内部投影容器。此模块发生在将投影后的查询/键/值重塑为多个头之前。请参阅“Attention Is All You Need”论文图2中多头注意力(底部)的线性层。还可以查看torchtext.nn.MultiheadAttentionContainer中的使用示例。

参数:
  • query_proj – 查询的投影层。一个典型的投影层是torch.nn.Linear。

  • key_proj – 键的投影层。一个典型的投影层是torch.nn.Linear。

  • value_proj – 值的投影层。一个典型的投影层是torch.nn.Linear。

forward(query: Tensor, key: Tensor, value: Tensor) Tuple[Tensor, Tensor, Tensor][source]

使用内部投影层投影输入序列。查询/键/值分别简单地传递到query/key/value_proj的前向函数。

参数:
  • query (Tensor) – 要投影的查询。

  • key (Tensor) – 要投影的键。

  • value (Tensor) – 要投影的值。

示例:
>>> import torch
>>> from torchtext.nn import InProjContainer
>>> embed_dim, bsz = 10, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> q = torch.rand((5, bsz, embed_dim))
>>> k = v = torch.rand((6, bsz, embed_dim))
>>> q, k, v = in_proj_container(q, k, v)

ScaledDotProduct

class torchtext.nn.ScaledDotProduct(dropout=0.0, batch_first=False)[source]
__init__(dropout=0.0, batch_first=False) None[source]

处理投影后的查询和键值对以应用缩放点积注意力。

参数:
  • dropout (float) – 丢弃注意力权重的概率。

  • batch_first – 如果为True,则输入和输出张量将作为(batch, seq, feature)提供。默认值:False

示例:
>>> import torch, torchtext
>>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
>>> q = torch.randn(21, 256, 3)
>>> k = v = torch.randn(21, 256, 3)
>>> attn_output, attn_weights = SDP(q, k, v)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor][source]

使用与投影后的键值对的缩放点积来更新投影后的查询。

参数:
  • query (Tensor) – 投影后的查询

  • key (Tensor) – 投影后的键

  • value (Tensor) – 投影后的值

  • attn_mask (BoolTensor, 可选) – 用于阻止对某些位置进行注意力的 3D 掩码。

  • attn_mask – 3D掩码,用于阻止对某些位置的注意力。

  • bias_k (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的键中。这些用于增量解码。用户应提供 bias_v

  • bias_v (Tensor, 可选) – 另一个键值序列,将添加到序列维度 (dim=-3) 上的值中。这些用于增量解码。用户也应提供 bias_k

形状
  • query: \((..., L, N * H, E / H)\)

  • key: \((..., S, N * H, E / H)\)

  • value: \((..., S, N * H, E / H)\)

  • attn_mask: \((N * H, L, S)\), 值为True的位置不允许参与注意力

    而值为False的位置将保持不变。

  • bias_k 和 bias_v:bias: \((1, N * H, E / H)\)

  • 输出: \((..., L, N * H, E / H)\), \((N * H, L, S)\)

注意:查询/键/值输入可以有超过三个维度(用于广播目的),这是可选的。

ScaledDotProduct模块将对最后三个维度进行操作。

其中L是目标长度,S是源长度,H是注意力头的数量,N是批次大小,E是嵌入维度。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源