快捷方式

update_state_dict_for_classifier

torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[源代码]

验证分类器模型检查点加载的状态字典。在调用 model.load_state_dict(state_dict) 之前使用。如果 output.weight 的形状不匹配,此函数将覆盖要加载的状态字典中的 output.weight,使其与模型中的 output.weight 相同。您可能还希望覆盖此行为,例如,如果您的检查点和模型的 num_classes 相同。

具体而言,当从基础语言模型的检查点微调分类器模型时,该基础语言模型的 output.weight 的形状为 [vocab_dim, embed_dim],我们将要加载的状态字典中的 output.weight 替换为模型中随机初始化的 [num_classes, embed_dim] 权重。这是就地完成的。

参数:
  • state_dict (Dict[str, torch.Tensor]) – 要加载到分类器模型中的状态字典。

  • model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – 来自 model.named_parameters() 的模型命名参数。

  • force_override (bool) – 是否用模型的 output.weight 替换 state_dict 中的 output.weight,即使形状匹配。

注释

  • 如果 state_dict 中存在 output.bias,则会忽略它

  • 如果 output.weight != model.output.weight,此函数将始终替换 state_dict 中的 output.weight

    如果 output.weight != model.output.weight

引发:
  • AssertionError – 如果 state_dict 不包含 output.weight

  • AssertionError – 如果 model_named_parameters 不包含 output.weight

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源