快捷方式

FullModelTorchTuneCheckpointer

class torchtune.training.FullModelTorchTuneCheckpointer(checkpoint_dir: str, checkpoint_files: List[str], model_type: ModelType, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False)[source]

检查点管理器,以与 torchtune 兼容的格式读取和写入检查点。不需要转换权重。

目前,此管理器仅支持读取单个检查点文件。随着我们对更大模型的支持,这可能会发生变化。

参数:
  • checkpoint_dir (str) – 包含检查点文件的目录

  • checkpoint_files (List[str]) – 要加载的检查点文件列表。由于检查点管理器会负责根据文件 ID 进行排序,因此此列表中的顺序无关紧要

  • model_type (ModelType) – 加载检查点管理器的模型的模型类型

  • output_dir (str) – 保存检查点文件的目录

  • adapter_checkpoint (Optional[str]) – 适配器权重的路径。默认为 None

  • recipe_checkpoint (Optional[str]) – 食谱状态检查点文件的路径。默认为 None

  • resume_from_checkpoint (bool) – 如果为 True,则检查点管理器将加载其他检查点文件,以从以前的运行中恢复训练。默认为 False

引发:
  • ValueError – 如果提供了多个检查点文件

  • ValueError – 如果检查点文件没有 .pt 扩展名

  • ValueError – 如果 resume_from_checkpoint 为 True 但 recipe_checkpoint 为 None

load_checkpoint(weights_only: bool = True) Dict[str, Any][source]

从文件中加载 torchtune 检查点。目前仅支持从单个文件加载。

输出 state_dict 具有以下格式,其中除“model”之外的键仅在 resume_from_checkpoint 为 True 时存在

>>>     {
>>>         "model": {
>>>             "key_1": weight
>>>             ...
>>>         },
>>>         "optimizer": {...},
>>>         ...
>>>     }
参数:

weights_only (bool) – 传递给 torch.load 的标志。我们公开此标志,因为量化模型无法使用 weights_only=True 加载

返回值:

输入检查点的 state_dict

返回类型:

Dict[str, Any]

保存检查点(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) None[source]

将 torchtune 检查点保存到文件。如果 intermediate_checkpoint 为 True,则会在 _output_dir 中创建一个额外的检查点文件 recipe_state.pt,其中包含配方状态。输出状态字典具有以下格式

>>> # Model
>>> {
>>>     "key_1": weight
>>>     ...
>>> }
>>>
>>> # Recipe state
>>> {
>>>     "optimizer": ...,
>>>     "epoch": ...,
>>>     ...
>>> }
参数:
  • state_dict (Dict[str, Any]) – 包含模型和(可选)配方状态的状态字典

  • epoch (int) – 当前纪元数。这将添加到检查点文件名中,以确保我们不会覆盖中间检查点文件

  • intermediate_checkpoint (bool) – 如果为 True,则保存一个包含配方状态的额外检查点文件

  • adapter_only (bool) – 如果为 True,则仅保存适配器权重。默认为 False

引发:

ValueError – 如果 adapter_only 为 True 且在 state_dict 中未找到适配器检查点。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源