FusionLayer¶
- class torchtune.modules.model_fusion.FusionLayer(layer: Module, fusion_layer: Module, fusion_first: bool = True)[source]¶
Fusion 层,如 Flamingo: a Visual Language Model for Few-Shot Learning 中介绍的。
深度融合模型架构通过将编码器输出注入到 LLM 的中间层中,将预训练的编码器模型与预训练的语言模型结合起来。这使得语言模型可以将编码器输出解释为文本,并“理解”任何可以训练编码器的模态。为了使语言模型适应编码器输出,FusionLayer 将一个新的可学习层融合到一个现有的解码器(语言模型)层中。这个附加层可以接收编码器嵌入,并学习将它们与来自解码器的 token 嵌入结合起来。该模块支持将新层融合在原始层之前或之后,在 Flamingo 中,新层融合在原始层之前。
原始层被包装在 FusionLayer 中,以便保持其原始的 state_dict key,并且不会破坏预训练的检查点。新层的参数通过
fusion_params
可用,可以单独控制它们是否可训练。示例
>>> # Original decoder style transformer >>> layer = nn.TransformerSelfAttentionLayer(...) >>> model = TransformerDecoder(layers=layer, num_layers=32, ...) >>> >>> # Fuse a cross attention layer to each self attention layer to adapt for the encoder >>> fusion_layer = nn.TransformerCrossAttentionLayer(...) >>> fused_layer = FusionLayer(layer, fusion_layer) >>> model = TransformerDecoder(layers=fused_layer, num_layers=32, ...) >>> >>> # Original decoder state_dict still works >>> model.load_state_dict(..., strict=False)
- 参数:
layer (nn.Module) – 原始解码器层
fusion_layer (nn.Module) – 新的融合层
fusion_first (bool) – 布尔值,指示在解码器层之前或之后插入融合层。
- caches_are_enabled() bool [source]¶
检查
self.layer
上的键值缓存是否已启用。请参阅 :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`。
- caches_are_setup() bool [source]¶
检查
self.layer
上的键值缓存是否已设置。请参阅 :func:~torchtune.modules.TransformerDecoder.caches_are_setup`。