FullModelTorchTuneCheckpointer¶
- class torchtune.training.FullModelTorchTuneCheckpointer(checkpoint_dir: str, checkpoint_files: List[str], model_type: ModelType, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False)[source]¶
检查点管理器,以与 torchtune 兼容的格式读取和写入检查点。不需要转换权重。
目前,此管理器仅支持读取单个检查点文件。随着我们对更大模型的支持,这可能会发生变化。
- 参数:
checkpoint_dir (str) – 包含检查点文件的目录
checkpoint_files (List[str]) – 要加载的检查点文件列表。由于检查点管理器会负责根据文件 ID 进行排序,因此此列表中的顺序无关紧要
model_type (ModelType) – 加载检查点管理器的模型的模型类型
output_dir (str) – 保存检查点文件的目录
adapter_checkpoint (Optional[str]) – 适配器权重的路径。默认为 None
recipe_checkpoint (Optional[str]) – 食谱状态检查点文件的路径。默认为 None
resume_from_checkpoint (bool) – 如果为 True,则检查点管理器将加载其他检查点文件,以从以前的运行中恢复训练。默认为 False
- 引发:
ValueError – 如果提供了多个检查点文件
ValueError – 如果检查点文件没有 .pt 扩展名
ValueError – 如果
resume_from_checkpoint
为 True 但recipe_checkpoint
为 None
- load_checkpoint(weights_only: bool = True) Dict[str, Any] [source]¶
从文件中加载 torchtune 检查点。目前仅支持从单个文件加载。
输出 state_dict 具有以下格式,其中除“model”之外的键仅在
resume_from_checkpoint
为 True 时存在>>> { >>> "model": { >>> "key_1": weight >>> ... >>> }, >>> "optimizer": {...}, >>> ... >>> }
- 保存检查点(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) None [source]¶
将 torchtune 检查点保存到文件。如果
intermediate_checkpoint
为 True,则会在_output_dir
中创建一个额外的检查点文件recipe_state.pt
,其中包含配方状态。输出状态字典具有以下格式>>> # Model >>> { >>> "key_1": weight >>> ... >>> } >>> >>> # Recipe state >>> { >>> "optimizer": ..., >>> "epoch": ..., >>> ... >>> }
- 参数:
- 引发:
ValueError – 如果
adapter_only
为 True 且在 state_dict 中未找到适配器检查点。