update_state_dict_for_classifier¶
- torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[源代码]¶
验证用于分类器模型检查点加载的状态字典。在调用
model.load_state_dict(state_dict)
之前使用。如果output.weight
的形状不匹配,此函数将覆盖要加载的状态字典中的output.weight
,使其与模型中的output.weight
相同。您可能还希望覆盖此行为,例如,如果检查点和模型的num_classes
相同。具体来说,当从具有
output.weight
形状为[vocab_dim, embed_dim]
的基础语言模型的检查点微调分类器模型时,我们将覆盖要加载的状态字典中的output.weight
,使其与模型中随机初始化的[num_classes, embed_dim]
权重相同。这是就地完成的。- 参数:
state_dict (Dict[str, torch.Tensor]) – 要加载到分类器模型中的状态字典。
model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – 来自
model.named_parameters()
的模型命名参数。force_override (bool) – 是否替换
state_dict
中的output.weight
,即使形状匹配。
备注
如果
state_dict
中存在output.bias
,它将被忽略- 此函数将始终替换
state_dict
中的output.weight
, 如果
output.weight != model.output.weight
。
- 此函数将始终替换
- 引发:
AssertionError – 如果
state_dict
不包含output.weight
。AssertionError – 如果
model_named_parameters
不包含output.weight
。