FullModelHFCheckpointer¶
- class torchtune.training.FullModelHFCheckpointer(checkpoint_dir: str, checkpoint_files: Union[List[str], Dict[str, str]], model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, safe_serialization: bool = True)[源代码]¶
检查点器,用于读取和写入 HF 格式的检查点。对于 LoRA 模型,这包括以可通过 PEFT 加载的格式保存检查点,例如通过
from_pretrained
。示例包括来自 meta-llama 仓库的 Llama-2-7b-hf 模型 (https://hugging-face.cn/meta-llama/Llama-2-7b-hf)。注意
HF 检查点名称通常按 ID 排序(例如:0001_of_0003、0002_of_0003 等)。为确保我们以正确的顺序读取文件,我们在读取之前对检查点文件名进行排序。
注意
与 HF 格式之间的检查点转换需要访问模型参数,这些参数直接从
config.json
文件中读取。这有助于确保我们正确加载权重,或者在 HF 检查点文件与 torchtune 的模型实现之间存在差异时报错。- 参数:
checkpoint_dir (str) – 包含检查点文件的目录
checkpoint_files (Union[List[str], Dict[str, str]]) – 要加载的检查点文件列表或包含键 [“filename_format”, “max_filename”] 的字典。由于检查点器负责按文件 ID 排序,因此此列表中的顺序无关紧要。
model_type (str) – 要加载检查点器的模型的模型类型,例如 LLAMA3。
output_dir (str) – 用于保存检查点文件的目录
adapter_checkpoint (Optional[str]) – 适配器权重的路径。如果为 None 且 resume_from_checkpoint=True,则在 output_dir/epoch_{largest_epoch} 中查找 adapter_model.pt。默认为 None。
recipe_checkpoint (Optional[str]) – 配方状态检查点文件的路径。如果为 None 且 resume_from_checkpoint=True,则在 output_dir/RECIPE_STATE_DIRNAME 中查找 recipe_state.pt。默认为 None。
resume_from_checkpoint (bool) – 如果为 True,则检查点器将加载其他检查点文件以从之前的运行恢复训练。默认为 False
safe_serialization (bool) – 如果为 True,则检查点器将使用 safetensors 保存检查点文件。默认为 True。
- load_checkpoint() Dict[str, Any] [源代码]¶
从文件加载 HF 检查点。
来自所有检查点文件的键和权重合并到一个 state_dict 中。我们在 weight_map 中保留 “state_dict 键” <-> “检查点文件” 映射,以便我们可以在
save_checkpoint
中正确写入状态字典。在返回之前,模型状态字典使用适当的 convert_weights 函数(取决于
self._model_type
)转换为 torchtune 兼容格式。- 返回:
torchtune 检查点状态字典
- 返回类型:
state_dict (Dict[str, Any])
- 引发:
ValueError – 如果输入 state_dict 中的值不是 Tensor
- save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) None [源代码]¶
将 HF 检查点保存到文件。如果
intermediate_checkpoint
为 True,则在_output_dir/RECIPE_STATE_DIRNAME
中创建一个额外的检查点文件recipe_state.pt
,其中包含配方状态。state_dict 首先转换回 HF 格式,然后根据
_weight_map
分区为单独的检查点文件。- 参数:
- 引发:
ValueError – 如果
adapter_only
为 True 且在 state_dict 中未找到适配器检查点。