快捷方式

update_state_dict_for_classifier

torchtune.training.update_state_dict_for_classifier(state_dict: Dict[str, Tensor], model_named_parameters: Iterable[Tuple[str, Parameter]], force_override: bool = False)[源代码]

验证用于分类器模型检查点加载的状态字典。在调用 model.load_state_dict(state_dict) 之前使用。如果 output.weight 的形状不匹配,此函数将覆盖要加载的状态字典中的 output.weight,使其与模型中的 output.weight 相同。您可能还希望覆盖此行为,例如,如果检查点和模型的 num_classes 相同。

具体来说,当从具有 output.weight 形状为 [vocab_dim, embed_dim] 的基础语言模型的检查点微调分类器模型时,我们将覆盖要加载的状态字典中的 output.weight,使其与模型中随机初始化的 [num_classes, embed_dim] 权重相同。这是就地完成的。

参数:
  • state_dict (Dict[str, torch.Tensor]) – 要加载到分类器模型中的状态字典。

  • model_named_parameters (Iterable[Tuple[str, torch.nn.Parameter]]) – 来自 model.named_parameters() 的模型命名参数。

  • force_override (bool) – 是否替换 state_dict 中的 output.weight,即使形状匹配。

备注

  • 如果 state_dict 中存在 output.bias,它将被忽略

  • 此函数将始终替换 state_dict 中的 output.weight

    如果 output.weight != model.output.weight

引发:
  • AssertionError – 如果 state_dict 不包含 output.weight

  • AssertionError – 如果 model_named_parameters 不包含 output.weight

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源