快捷方式

FullModelHFCheckpointer

class torchtune.training.FullModelHFCheckpointer(checkpoint_dir: str, checkpoint_files: Union[List[str], Dict[str, str]], model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, safe_serialization: bool = True, should_load_recipe_state: bool = False)[source]

以 HF 格式读写检查点的 Checkpointer。对于 LoRA 模型,这包括以可通过例如 from_pretrained 加载到 PEFT 中的格式保存检查点。示例包括 meta-llama 仓库中的 Llama-2-7b-hf 模型 (https://hugging-face.cn/meta-llama/Llama-2-7b-hf)。

注意

HF 检查点名称通常按 ID 排序(例如:0001_of_0003, 0002_of_0003 等)。为了确保我们以正确的顺序读取文件,我们在读取前对检查点文件名进行排序。

注意

检查点与 HF 格式之间的转换需要访问直接从 config.json 文件读取的模型参数。这有助于确保我们能正确加载权重,或者在 HF 检查点文件与 torchtune 模型实现之间存在差异时报错。

参数:
  • checkpoint_dir (str) – 包含检查点文件的目录

  • checkpoint_files (Union[List[str], Dict[str, str]]) – 要加载的检查点文件列表,或包含键 ["filename_format", "max_filename"] 的字典。由于 checkpointer 负责按文件 ID 排序,因此此列表中的顺序无关紧要。

  • model_type (str) – 正在加载 checkpointer 的模型的模型类型,例如 LLAMA3。

  • output_dir (str) – 保存检查点文件的目录

  • adapter_checkpoint (Optional[str]) – adapter 权重的路径。如果为 None 且 should_load_recipe_state=True,则在 output_dir/epoch_{largest_epoch} 中查找 adapter_model.pt。默认为 None。

  • recipe_checkpoint (Optional[str]) – recipe 状态检查点文件的路径。如果为 None 且 should_load_recipe_state=True,则在 output_dir/RECIPE_STATE_DIRNAME 中查找 recipe_state.pt。默认为 None。

  • resume_from_checkpoint (bool) – 如果为 True,checkpointer 将从之前运行中加载与 recipe 状态对应的额外检查点文件。默认为 False。此标志已弃用。请改用 should_load_recipe_state 标志。

  • safe_serialization (bool) – 如果为 True,checkpointer 将使用 safetensors 保存检查点文件。默认为 True。

  • should_load_recipe_state (bool) – 如果为 True,checkpointer 将从之前运行中加载与 recipe 状态对应的额外检查点文件。默认为 False。

load_checkpoint() Dict[str, Any][source]

从文件中加载 HF 检查点。

所有检查点文件中的键和权重会合并到一个 state_dict 中。我们在 weight_map 中保留“state_dict key” <-> “checkpoint file” 映射,以便在 save_checkpoint 中正确写入 state dict。

返回之前,模型 state dict 会使用适当的 convert_weights 函数(取决于 self._model_type)转换为 torchtune 兼容的格式。

返回:

torchtune 检查点状态字典

返回类型:

state_dict (Dict[str, Any])

抛出:

ValueError – 如果输入 state_dict 中的值不是 Tensors

save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) None[source]

将 HF 检查点保存到文件。如果 intermediate_checkpoint 为 True,则会在 _output_dir/RECIPE_STATE_DIRNAME 中创建一个额外的检查点文件 recipe_state.pt,其中包含 recipe 状态。

state_dict 会首先转换回 HF 格式,然后根据 _weight_map 划分到单独的检查点文件中。

参数:
  • state_dict (Dict[str, Any]) – 要写入文件的检查点状态字典

  • epoch (int) – epoch 编号。用于创建检查点文件名

  • intermediate_checkpoint (bool) – 如果为 True,将创建额外的检查点文件用于保存 recipe 状态和(如果适用)adapter 权重。默认为 False

  • adapter_only (bool) – 如果为 True,则仅保存 adapter 权重。默认为 False

抛出:

ValueError – 如果 adapter_only 为 True 且 state_dict 中找不到 adapter 检查点。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源