快捷方式

DeepFusionModel

class torchtune.modules.model_fusion.DeepFusionModel(decoder: TransformerDecoder, encoder: Module, *, decoder_trainable: bool = False, encoder_trainable: bool = False, fusion_trainable: bool = True)[源代码]

DeepFusion 是一种融合模型架构,其中预训练编码器与预训练解码器 (LLM) 在内部解码器层中结合。这是一种流行的多模态模型架构,完整概述可在《多模态模型架构的演变》中找到。一种常见的深度融合架构是通过间隔的交叉注意力层将编码器输入融合到解码器中。此模块对编码器和解码器如何融合不作任何假设;它只是将编码器嵌入传递给解码器,并让解码器处理任何融合。

此模块与 TransformerDecoder 具有相同的方法和 forward 签名,可以在任何使用 TransformerDecoder 的地方互换使用。它将编码器和解码器组合成一个单一模块,用于检查点和微调。期望编码器和解码器已定义好,并包含任何额外的可学习 fusion_params:用于帮助使预训练编码器适应预训练解码器的可学习参数。

DeepFusionModel 当前仅支持单个编码器。

示例: >>> # decoder 是一个带有融合交叉注意力层的 TransformerDecoder(例如 llama3_8b) >>> embed = FusionEmbedding(...) >>> layer = FusionLayer( … layer=TransformerSelfAttentionLayer(...), … fusion_layer=TransformerCrossAttentionLayer(...), … ) >>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...) >>> >>> # encoder 是一个带有附加投影头(projection head)的预训练编码器(例如 clip_vit_224) >>> projection_head = FeedForward(...) >>> register_fusion_module(projection_head)) >>> encoder = nn.Sequential(clip_vit_224(), projection_head) >>> >>> # DeepFusionModel 结合了编码器和解码器 >>> model = DeepFusionModel(decoder, encoder) >>> >>> # 加载完整的融合检查点(例如 Llama3.2 Vision 检查点) >>> model.load_state_dict(...) >>> >>> # 或者加载预训练的独立模型(不加载 fusion_params) >>> model.encoder.load_state_dict(..., strict=False) >>> model.decoder.load_state_dict(..., strict=False) >>> >>> # 前向传播 >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos)

参数:
  • decoder (TransformerDecoder) – 解码器模块

  • encoder (nn.Module) – 编码器模块

  • decoder_trainable (bool) – 是否训练或冻结解码器。默认为 False。

  • encoder_trainable (bool) – 是否训练或冻结编码器。默认为 False。

  • fusion_trainable (bool) – 是否训练融合参数。默认为 True。

caches_are_enabled() bool[源代码]

检查键值缓存是否启用。一旦 KV 缓存设置完毕,相关的注意力模块将被“启用”,并且所有前向传播都会更新缓存。可以通过使用 disable_kv_cache() “禁用” KV 缓存来禁用此行为,而无需改变 KV 缓存的状态,此时 caches_are_enabled 将返回 False。

caches_are_setup() bool[源代码]

检查键值缓存是否已设置。这意味着已调用 setup_caches,并且模型中相关的注意力模块已创建其 KVCache

forward(tokens: Tensor, *, mask: Optional[Tensor] = None, encoder_input: Optional[Dict] = None, encoder_mask: Optional[Tensor] = None, input_pos: Optional[Tensor] = None) Union[Tensor, List[Tensor]][源代码]
参数:
  • tokens (torch.Tensor) – 输入张量,形状为 [b x s]

  • mask (Optional[torch.Tensor]) – 可选的布尔张量,包含形状为 [b x s x s] 的注意力 mask。这在查询-键乘法之后、softmax 之前应用。第 i 行和第 j 列的值为 True 表示 token i 注意 token j。值为 False 表示 token i 不注意 token j。如果未指定 mask,则默认使用因果 mask。默认为 None。

  • encoder_input (Optional[Dict]) – 编码器的可选输入。

  • encoder_mask (Optional[torch.Tensor]) – 布尔张量,定义了 token 与编码器嵌入之间的关系矩阵。位置 i,j 的值为 True 表示 token i 可以在解码器中注意嵌入 j。Mask 形状为 [b x s x s_e]。默认为 None。

  • input_pos (Optional[torch.Tensor]) – 可选张量,包含每个 token 的位置 ID。在训练期间,用于指示打包时每个 token 相对于其样本的位置,形状为 [b x s]。在推理期间,指示当前 token 的位置。如果为 None,则假定 token 的索引为其位置 ID。默认为 None。

注意:在推理的第一步,当模型接收到 prompt 时,input_pos 将包含 prompt 中所有 token 的位置(例如:torch.arange(prompt_length))。这是因为我们需要计算每个位置的 KV 值。

返回值:

输出张量形状为 [b x s x v],或者由 output_hidden_states 定义的层输出张量列表,最终输出张量附加到列表末尾。

返回类型:

Tensor

张量形状使用的符号说明
  • b: 批处理大小

  • s: token 序列长度

  • s_e: 编码器序列长度

  • v: 词汇表大小

  • d: token 嵌入维度

  • d_e: 编码器嵌入维度

  • m_s: 最大序列长度

reset_caches()[源代码]

将相关注意力模块上的 KV 缓存缓冲区重置为零,并将缓存位置重置为零,而无需删除或重新分配缓存张量。

set_num_output_chunks(num_output_chunks: int) None[源代码]

CEWithChunkedOutputLoss 结合使用以节省内存。应在第一次前向传播之前在 recipe 中调用此函数。

setup_caches(batch_size: int, dtype: dtype, *, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None)[源代码]

self.decoder 设置用于推理的键值注意力缓存。对于 self.decoder.layers 中的每一层:- torchtune.modules.TransformerSelfAttentionLayer 将使用 decoder_max_seq_len。- torchtune.modules.TransformerCrossAttentionLayer 将使用 encoder_max_seq_len。- torchtune.modules.fusion.FusionLayer 将同时使用 decoder_max_seq_lenencoder_max_seq_len

参数:
  • batch_size (int) – 缓存的批处理大小。

  • dtype (torch.dpython:type) – 缓存的数据类型。

  • encoder_max_seq_len (Optional[int]) – 编码器缓存最大序列长度。

  • decoder_max_seq_len (Optional[int]) – 解码器缓存最大序列长度。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源