update_state_dict_for_classifier¶
- torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[来源]¶
验证用于分类器模型的检查点加载的状态字典。应在调用
model.load_state_dict(state_dict)
之前使用。如果待加载状态字典中的output.weight
的形状与模型中的output.weight
不匹配,此函数将使用模型中的output.weight
覆盖待加载状态字典中的output.weight
。你可能也希望覆盖此行为,例如,如果你的检查点和模型的num_classes
相同。具体来说,当从基础语言模型的检查点微调分类器模型时,基础语言模型的
output.weight
形状为[vocab_dim, embed_dim]
,我们会使用模型中随机初始化的[num_classes, embed_dim]
权重覆盖待加载状态字典中的output.weight
。这是原地操作。- 参数:
注意
如果
state_dict
中存在output.bias
,它将被忽略- 此函数将总是替换
state_dict
中的output.weight
, 如果
output.weight
与model.output.weight
不同。
- 此函数将总是替换
- 抛出异常:
AssertionError – 如果
state_dict
不包含output.weight
,**或**如果model_named_parameters
不包含output.weight
。