register_fusion_module¶
- torchtune.modules.model_fusion.register_fusion_module(module: Module)[source]¶
为 nn.Module 添加 fusion_params 方法,将模块的所有参数标记为融合参数。这可用于添加到组合两个或多个预训练模型的层或整个模型。
例如,您可能想在编码器上添加一个投影头 (projection head),以学习从预训练编码到解码器嵌入空间的投影。这在深度融合 (Deep Fusion) 和早期融合 (Early Fusion) 模型中都很常见。
示例
>>> projection_head = FeedForward(...) >>> register_fusion_module(projection_head)) >>> encoder = nn.Sequential(clip_vit_224(), projection_head)
- 参数:
module (nn.Module) – 要添加 fusion_params 方法的模块