快捷方式

update_state_dict_for_classifier

torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[来源]

验证用于分类器模型的检查点加载的状态字典。应在调用 model.load_state_dict(state_dict) 之前使用。如果待加载状态字典中的 output.weight 的形状与模型中的 output.weight 不匹配,此函数将使用模型中的 output.weight 覆盖待加载状态字典中的 output.weight。你可能也希望覆盖此行为,例如,如果你的检查点和模型的 num_classes 相同。

具体来说,当从基础语言模型的检查点微调分类器模型时,基础语言模型的 output.weight 形状为 [vocab_dim, embed_dim],我们会使用模型中随机初始化的 [num_classes, embed_dim] 权重覆盖待加载状态字典中的 output.weight。这是原地操作。

参数:
  • state_dict (Dict[str, torch.Tensor]) – 待加载到分类器模型中的状态字典。

  • model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – 来自 model.named_parameters() 的模型命名参数。

  • force_override (bool) – 是否用模型的 output.weight 替换 state_dict 中的 output.weight,即使形状匹配。

注意

  • 如果 state_dict 中存在 output.bias,它将被忽略

  • 此函数将总是替换 state_dict 中的 output.weight

    如果 output.weightmodel.output.weight 不同。

抛出异常:

AssertionError – 如果 state_dict 不包含 output.weight,**或**如果 model_named_parameters 不包含 output.weight

文档

访问 PyTorch 全面的开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源