FusionEmbedding¶
- class torchtune.modules.model_fusion.FusionEmbedding(vocab_size: int, fusion_vocab_size: int, embed_dim: int)[源代码]¶
Fusion embedding 支持训练额外的特殊 token,同时保持原始 embedding 冻结。当融合新模型和语言模型时,可能需要一些额外的 token 来支持融合的语言模型。例如,添加视觉编码器可能需要额外的 token,例如
<|image|>
来指示图像在文本中的位置,并需要学习此 token 的 embedding。FusionEmbedding 保持原始 embedding 冻结,同时为额外的 token 学习一个更小的第二个 embedding。在正向传播期间,此模块将 token 路由到适当的 embedding 表。使用此模块作为
torch.nn.Embedding
的直接替代品。示例
>>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128) >>> model = TransformerDecoder(tok_embeddings=embedding, ...) >>> >>> # Original model state_dict still works >>> model.load_state_dict(..., strict=False)
注意
此模块假定范围 [0, vocab_size) 中的所有 token 都是原始 embedding 表的一部分,而范围 [vocab_size, vocab_size + fusion_vocab_size) 中的所有新 token 都是额外的 token。
- 参数:
- forward(input: Tensor) Tensor [源代码]¶
- 参数:
input (torch.Tensor) – 形状为 [batch_size x seq_length] 的输入整数张量
- 返回:
- 形状为 [batch_size x seq_length x embed_dim] 的输出张量 embedding
[batch_size x seq_length x embed_dim]`
- 返回类型:
Tensor