快捷方式

register_fusion_module

torchtune.modules.model_fusion.register_fusion_module(module: Module)[source]

为 nn.Module 添加 fusion_params 方法,将模块的所有参数标记为融合参数。这可用于添加到组合两个或多个预训练模型的层或整个模型。

例如,您可能想在编码器上添加一个投影头 (projection head),以学习从预训练编码到解码器嵌入空间的投影。这在深度融合 (Deep Fusion) 和早期融合 (Early Fusion) 模型中都很常见。

示例

>>> projection_head = FeedForward(...)
>>> register_fusion_module(projection_head))
>>> encoder = nn.Sequential(clip_vit_224(), projection_head)
参数

module (nn.Module) – 要添加 fusion_params 方法的模块


© 版权所有 2023-至今,torchtune 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源