register_fusion_module¶
- torchtune.modules.model_fusion.register_fusion_module(module: Module)[source]¶
将方法 fusion_params 添加到 nn.Module 中,该方法将所有模块参数标记为融合参数。这可用于将两个或多个预训练模型组合在一起的层或整个模型。
例如,您可能希望在编码器上添加一个投影头 head,以学习从预训练编码到解码器嵌入空间的投影。这在深度融合和早期融合模型中都很常见。
示例
>>> projection_head = FeedForward(...) >>> register_fusion_module(projection_head)) >>> encoder = nn.Sequential(clip_vit_224(), projection_head)
- 参数:
module (nn.Module) – 要向其添加 fusion_params 方法的模块