快捷方式

FusionEmbedding

class torchtune.modules.model_fusion.FusionEmbedding(vocab_size: int, fusion_vocab_size: int, embed_dim: int)[source]

Fusion embedding 支持训练额外的特殊标记,同时保持原始嵌入冻结。当将新模型与语言模型融合时,可能需要一些额外的标记来支持融合后的语言模型。例如,添加视觉编码器可能需要像 <|image|> 这样的额外标记来指示图像在文本中的位置,并需要学习此标记的嵌入。FusionEmbedding 保持原始嵌入冻结,同时学习用于额外标记的更小的第二个嵌入。在正向传播期间,此模块会将标记路由到相应的嵌入表。

在您的模型中将此用作 torch.nn.Embedding 的直接替代品。

示例

>>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128)
>>> model = TransformerDecoder(tok_embeddings=embedding, ...)
>>>
>>> # Original model state_dict still works
>>> model.load_state_dict(..., strict=False)

注意

此模块假设原始嵌入表中包含范围 [0, vocab_size) 中的所有标记,融合模型中所有新的标记都在范围 [vocab_size, vocab_size + fusion_vocab_size) 内。

参数:
  • vocab_size (int) – 语言模型词汇量大小

  • fusion_vocab_size (int) – 融合模型的额外标记

  • embed_dim (int) – 两个嵌入表的嵌入维度

forward(input: Tensor) Tensor[source]
参数:

input (torch.Tensor) – 形状为 [batch_size x seq_length] 的输入整数张量

返回值:

形状为

[batch_size x seq_length x embed_dim]` 的输出张量嵌入

返回类型:

Tensor

fusion_params() List[str][source]

返回融合嵌入参数。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源