快捷方式

TransformerDecoder

class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[source][source]

TransformerDecoder 是一个包含 N 个解码器层的堆栈。

注意

请参阅本教程,深入了解 PyTorch 提供的用于构建您自己的 Transformer 层的性能构建块。

参数
  • decoder_layer (TransformerDecoderLayer) – TransformerDecoderLayer() 类的一个实例(必需)。

  • num_layers (int) – 解码器中的子解码器层数(必需)。

  • norm (Optional[Module]) – 层归一化组件(可选)。

示例:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[source][source]

依次通过解码器层传递输入(和掩码)。

参数
  • tgt (Tensor) – 传递给解码器的序列(必需)。

  • memory (Tensor) – 来自编码器最后一层的序列(必需)。

  • tgt_mask (Optional[Tensor]) – 用于 tgt 序列的掩码(可选)。

  • memory_mask (Optional[Tensor]) – 用于 memory 序列的掩码(可选)。

  • tgt_key_padding_mask (Optional[Tensor]) – 用于每个批次中 tgt 键的掩码(可选)。

  • memory_key_padding_mask (Optional[Tensor]) – 用于每个批次中 memory 键的掩码(可选)。

  • tgt_is_causal (Optional[bool]) – 如果指定,将应用因果掩码作为 tgt mask。默认值:None;尝试检测因果掩码。警告:tgt_is_causal 提供了一个提示,表明 tgt_mask 是因果掩码。提供不正确的提示可能导致错误的执行,包括前向和后向兼容性。

  • memory_is_causal (bool) – 如果指定,将应用因果掩码作为 memory mask。默认值:False。警告:memory_is_causal 提供了一个提示,表明 memory_mask 是因果掩码。提供不正确的提示可能导致错误的执行,包括前向和后向兼容性。

返回类型

Tensor

形状

参见 Transformer 中的文档。

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源