快捷方式

FusionLayer

class torchtune.modules.model_fusion.FusionLayer(layer: Module, fusion_layer: Module, fusion_first: bool = True)[源代码]

Fusion layer,如 Flamingo: a Visual Language Model for Few-Shot Learning 中介绍的那样。

深度融合模型架构通过将编码器输出注入 LLM 的中间层,将预训练的编码器模型与预训练的语言模型相结合。这使得语言模型可以将编码器输出解释为文本,并“理解”您可以为其训练编码器的任何模态。为了使语言模型能够适应编码器输出,FusionLayer 将一个新的可学习层融合到现有的解码器(语言模型)层。这个额外的层可以获取编码器嵌入,并学习将它们与来自解码器的 token 嵌入相结合。该模块支持在原始层之前或之后融合新层,在 Flamingo 中,新层在原始层之前融合。

原始层被包裹在 FusionLayer 中,这样它就可以保持其原始 state_dict 键,并且预训练的检查点不会被破坏。新层参数可通过 fusion_params 获得,以单独控制它们是否可训练。

示例

>>> # Original decoder style transformer
>>> layer = nn.TransformerSelfAttentionLayer(...)
>>> model = TransformerDecoder(layers=layer, num_layers=32, ...)
>>>
>>> # Fuse a cross attention layer to each self attention layer to adapt for the encoder
>>> fusion_layer = nn.TransformerCrossAttentionLayer(...)
>>> fused_layer = FusionLayer(layer, fusion_layer)
>>> model = TransformerDecoder(layers=fused_layer, num_layers=32, ...)
>>>
>>> # Original decoder state_dict still works
>>> model.load_state_dict(..., strict=False)
参数:
  • layer (nn.Module) – 原始解码器层

  • fusion_layer (nn.Module) – 新的融合层

  • fusion_first (bool) – 布尔值,用于在解码器层之前或之后插入融合层。

caches_are_enabled() bool[源代码]

检查 self.layer 上的键值缓存是否已启用。请参阅 :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`。

caches_are_setup() bool[源代码]

检查 self.layer 上的键值缓存是否已设置。请参阅 :func:~torchtune.modules.TransformerDecoder.caches_are_setup`。

forward(x: Tensor, **kwargs: Dict) Tensor[源代码]
参数:
  • x (torch.Tensor) – 输入张量,形状为 [batch_size x seq_length x embed_dim]

  • **kwargs (Dict) – 所有其他层参数

返回:

与输入形状相同的输出张量

[batch_size x seq_length x embed_dim]`

返回类型:

Tensor

fusion_params() List[str][源代码]

返回融合层的参数。

reset_cache()[源代码]

重置两个层的键值缓存。

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: int, decoder_max_seq_len: int) None[源代码]

为两个层设置键值缓存。

参数:
  • batch_size (int) – 缓存的批处理大小。

  • dtype (torch.dpython:type) – 缓存的数据类型。

  • encoder_max_seq_len (int) – 跨注意力层的最大缓存序列长度。

  • decoder_max_seq_len (int) – 自注意力层的最大缓存序列长度。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源