FusionEmbedding¶
- class torchtune.modules.model_fusion.FusionEmbedding(vocab_size: int, fusion_vocab_size: int, embed_dim: int)[源代码]¶
融合嵌入支持训练额外的特殊标记,同时保持原始嵌入冻结。当将新模型与语言模型融合时,可能需要一些额外的标记来支持融合后的语言模型。例如,添加一个视觉编码器可能需要像
<|image|>
这样的额外标记来指示图像在文本中的位置,并需要学习该标记的嵌入。FusionEmbedding 保持原始嵌入冻结,同时为额外标记学习一个更小的第二嵌入。在 forward 过程中,此模块将标记路由到相应的嵌入表。在你的模型中,可以将此作为
torch.nn.Embedding
的直接替换项使用。示例
>>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128) >>> model = TransformerDecoder(tok_embeddings=embedding, ...) >>> >>> # Original model state_dict still works >>> model.load_state_dict(..., strict=False)
注意
此模块假设范围 [0, vocab_size) 内的所有标记都属于原始嵌入表,所有新标记都在范围 [vocab_size, vocab_size + fusion_vocab_size)
- forward(input: Tensor) Tensor [源代码]¶
- 参数:
input (torch.Tensor) – 输入整数张量,形状为 [batch_size x seq_length]
- 返回:
- 输出张量嵌入,形状为
[batch_size x seq_length x embed_dim]`
- 返回类型:
Tensor