update_state_dict_for_classifier¶
- torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[源代码]¶
验证分类器模型检查点加载的状态字典。在调用
model.load_state_dict(state_dict)
之前使用。如果output.weight
的形状不匹配,此函数将覆盖要加载的状态字典中的output.weight
,使其与模型中的output.weight
相同。您可能还希望覆盖此行为,例如,如果您的检查点和模型的num_classes
相同。具体而言,当从基础语言模型的检查点微调分类器模型时,该基础语言模型的
output.weight
的形状为[vocab_dim, embed_dim]
,我们将要加载的状态字典中的output.weight
替换为模型中随机初始化的[num_classes, embed_dim]
权重。这是就地完成的。- 参数:
state_dict (Dict[str, torch.Tensor]) – 要加载到分类器模型中的状态字典。
model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – 来自
model.named_parameters()
的模型命名参数。force_override (bool) – 是否用模型的
output.weight
替换state_dict
中的output.weight
,即使形状匹配。
注释
如果
state_dict
中存在output.bias
,则会忽略它- 如果
output.weight != model.output.weight
,此函数将始终替换state_dict
中的output.weight
。 如果
output.weight != model.output.weight
。
- 如果
- 引发:
AssertionError – 如果
state_dict
不包含output.weight
。AssertionError – 如果
model_named_parameters
不包含output.weight
。