快捷方式

validate_state_dict_for_lora

torchtune.modules.peft.validate_state_dict_for_lora(lora_attn_modules: List[Literal['q_proj', 'k_proj', 'v_proj', 'output_proj']], apply_lora_to_mlp: bool, apply_lora_to_output: bool, full_model_state_dict_keys: List[str], lora_state_dict_keys: Optional[List[str]] = None, base_model_state_dict_keys: Optional[List[str]] = None) None[源代码]

验证 LoRA 模型的状态字典键是否符合预期。

  1. 如果传递了 lora_state_dict_keys,则此函数将确认它们与完整模型中的 LoRA 参数名称完全匹配(由 lora_modules 确定)。

  2. 如果传递了 base_model_state_dict_keys,则此函数将确认它们是完整模型中 LoRA 参数名称的完全补集。

  3. 如果同时传递了 lora_state_dict_keys 和 base_model_state_dict_keys,则此函数将确认完整模型的参数恰好是它们的不相交并集。

参数:
  • lora_attn_modules (List[LORA_ATTN_MODULES]) – 每个自注意力块中应将 LoRA 应用于哪些线性层的列表。选项为 {"q_proj", "k_proj", "v_proj", "output_proj"}

  • apply_lora_to_mlp (bool) – 是否将 LoRA 应用于每个 MLP 线性层。

  • apply_lora_to_output (bool) – 是否将 LoRA 应用于最终输出投影。

  • full_model_state_dict_keys (List[str]) – 完整模型状态字典中的键列表。

  • lora_state_dict_keys (Optional[List[str]]) – LoRA 状态字典中的键列表。如果为 None,则不会验证 LoRA 状态字典键。

  • base_model_state_dict_keys (Optional[List[str]]) – 基础模型状态字典中的键列表。如果为 None,则不会验证基础模型键。

返回值:

引发:
  • AssertionError – 如果基础模型状态字典缺少完整模型中的任何非 LoRA 参数。

  • AssertionError – 如果 LoRA 状态字典缺少完整模型中的任何 LoRA 参数。

  • AssertionError – 如果基础模型状态字典包含任何 LoRA 参数。

  • AssertionError – 如果 LoRA 状态字典包含任何非 LoRA 参数。

  • AssertionError – 如果基础模型和 LoRA 状态字典具有重叠的键。

  • AssertionError – 如果完整模型状态字典缺少基础模型或 LoRA 状态字典中的任何键。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源