validate_state_dict_for_lora¶
- torchtune.modules.peft.validate_state_dict_for_lora(lora_attn_modules: List[Literal['q_proj', 'k_proj', 'v_proj', 'output_proj']], apply_lora_to_mlp: bool, apply_lora_to_output: bool, full_model_state_dict_keys: List[str], lora_state_dict_keys: Optional[List[str]] = None, base_model_state_dict_keys: Optional[List[str]] = None) None [源代码]¶
验证 LoRA 模型的状态字典键是否符合预期。
如果传递了 lora_state_dict_keys,则此函数将确认它们与完整模型中的 LoRA 参数名称完全匹配(由 lora_modules 确定)。
如果传递了 base_model_state_dict_keys,则此函数将确认它们是完整模型中 LoRA 参数名称的完全补集。
如果同时传递了 lora_state_dict_keys 和 base_model_state_dict_keys,则此函数将确认完整模型的参数恰好是它们的不相交并集。
- 参数:
lora_attn_modules (List[LORA_ATTN_MODULES]) – 每个自注意力块中应将 LoRA 应用于哪些线性层的列表。选项为
{"q_proj", "k_proj", "v_proj", "output_proj"}
。apply_lora_to_mlp (bool) – 是否将 LoRA 应用于每个 MLP 线性层。
apply_lora_to_output (bool) – 是否将 LoRA 应用于最终输出投影。
full_model_state_dict_keys (List[str]) – 完整模型状态字典中的键列表。
lora_state_dict_keys (Optional[List[str]]) – LoRA 状态字典中的键列表。如果为 None,则不会验证 LoRA 状态字典键。
base_model_state_dict_keys (Optional[List[str]]) – 基础模型状态字典中的键列表。如果为 None,则不会验证基础模型键。
- 返回值:
无
- 引发:
AssertionError – 如果基础模型状态字典缺少完整模型中的任何非 LoRA 参数。
AssertionError – 如果 LoRA 状态字典缺少完整模型中的任何 LoRA 参数。
AssertionError – 如果基础模型状态字典包含任何 LoRA 参数。
AssertionError – 如果 LoRA 状态字典包含任何非 LoRA 参数。
AssertionError – 如果基础模型和 LoRA 状态字典具有重叠的键。
AssertionError – 如果完整模型状态字典缺少基础模型或 LoRA 状态字典中的任何键。