TransformerEncoder¶
- class torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True)[source]¶
TransformerEncoder 是 N 个编码器层的堆栈。
用户可以使用相应的参数构建 BERT(https://arxiv.org/abs/1810.04805) 模型。
- 参数
encoder_layer (TransformerEncoderLayer) – TransformerEncoderLayer() 类的实例(必需)。
num_layers (int) – 编码器中的子编码器层数(必需)。
enable_nested_tensor (bool) – 如果为 True,输入将自动转换为嵌套张量(并在输出时转换回来)。当填充率很高时,这将提高 TransformerEncoder 的整体性能。默认值:
True
(已启用)。
- 示例:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src)
- forward(src, mask=None, src_key_padding_mask=None, is_causal=None)[source]¶
依次通过编码器层传递输入。
- 参数
- 返回类型
- 形状
请参阅
Transformer
中的文档。