TransformerDecoder¶
- class torch.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)[源代码][源代码]¶
TransformerDecoder 是 N 个解码器层的堆栈。
注意
有关 PyTorch 为构建您自己的 Transformer 层提供的性能构建块的深入讨论,请参阅本教程。
- 参数
decoder_layer (TransformerDecoderLayer) – TransformerDecoderLayer() 类的实例(必需)。
num_layers (int) – 解码器中子解码器层的数量(必需)。
- 示例:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) >>> memory = torch.rand(10, 32, 512) >>> tgt = torch.rand(20, 32, 512) >>> out = transformer_decoder(tgt, memory)
- forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=None, memory_is_causal=False)[源代码][源代码]¶
依次将输入(和掩码)传递到解码器层。
- 参数
tgt (Tensor) – 要解码的序列(必需)。
memory (Tensor) – 来自编码器最后一层的序列(必需)。
tgt_key_padding_mask (Optional[Tensor]) – 每个批次 tgt 键的掩码(可选)。
memory_key_padding_mask (Optional[Tensor]) – 每个批次 memory 键的掩码(可选)。
tgt_is_causal (Optional[bool]) – 如果指定,则将因果掩码应用为
tgt mask
。默认值:None
;尝试检测因果掩码。警告:tgt_is_causal
提供了一个提示,表明tgt_mask
是因果掩码。提供不正确的提示可能会导致执行不正确,包括向前和向后兼容性。memory_is_causal (bool) – 如果指定,则将因果掩码应用为
memory mask
。默认值:False
。警告:memory_is_causal
提供了一个提示,表明memory_mask
是因果掩码。提供不正确的提示可能会导致执行不正确,包括向前和向后兼容性。
- 返回类型
- 形状
请参阅
Transformer
中的文档。