• 文档 >
  • torchtune 中的检查点
快捷方式

torchtune 中的检查点

本深入探讨将带您了解检查点和相关实用程序的设计和行为。

本深入探讨将涵盖的内容
  • torchtune 检查点设计

  • 检查点格式以及我们如何处理它们

  • 检查点场景:中间 vs 最终和 LoRA vs 全微调

概述

torchtune 检查点旨在成为可组合的组件,可以插入任何食谱 - 训练、评估或生成。每个检查点都支持一组模型和场景,使这些模型易于理解、调试和扩展。

在我们深入了解 torchtune 中的检查点之前,让我们定义一些概念。


检查点格式

在本深入探讨中,我们将讨论不同的检查点格式以及 torchtune 如何处理它们。让我们仔细看看这些不同的格式。

简单地说,检查点的格式由 state_dict 和它如何在磁盘上的文件中存储的方式决定。每个权重都与一个字符串键关联,该键在 state dict 中识别它。如果存储的检查点中的键的字符串标识符与模型定义中的字符串标识符不完全匹配,您将遇到显式错误(加载 state dict 将引发异常)或更糟糕的 - 静默错误(加载将成功,但训练或推理将无法按预期工作)。除了键对齐外,您还需要权重(state_dict 中的值)的形状与模型定义预期的形状完全匹配。

让我们看看 Llama2 的两种流行格式。

Meta 格式

这是官方 Llama2 实现支持的格式。当您从 meta-llama 网站 下载 Llama2 7B 模型时,您将获得一个 .pth 检查点文件的访问权限。您可以使用 torch.load 轻松检查此检查点的内容

>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>>    print(f'{key}: {value.shape}')

tok_embeddings.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
292

state_dict 包含 292 个键,包括一个名为 tok_embeddings 的输入嵌入表。此 state_dict 的模型定义期望一个具有 32000 个标记的嵌入层,每个标记的嵌入维度为 4096

HF 格式

这是 Hugging Face 模型中心中最流行的格式,也是每个 torchtune 配置的默认格式。这也是您从 Llama-2-7b-hf 存储库下载 llama2 模型时获得的格式。

第一个主要区别是 state_dict 分布在两个 .bin 文件中。要正确加载检查点,您需要将这些文件拼凑在一起。让我们检查其中一个文件。

>>> import torch
>>> state_dict = torch.load('pytorch_model-00001-of-00002.bin', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>>     print(f'{key}: {value.shape}')

model.embed_tokens.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
241

state_dict 不仅包含更少的键(因为这是两个文件中的一个),而且嵌入表被称为 model.embed_tokens 而不是 tok_embeddings。这种名称不匹配会导致您尝试加载 state_dict 时出现异常。该层的尺寸在两者之间相同,这是预期的。


如您所见,如果您不小心,您可能会在检查点加载和保存过程中遇到很多错误。torchtune 检查点通过为您管理 state dict 使此过程不易出错。torchtune 旨在“state-dict 不变”。

  • 加载时,torchtune 接受来自多种来源的多种格式的检查点。您不必担心在每次运行食谱时显式转换检查点。

  • 保存时,torchtune 以与源相同的格式生成检查点。这包括将 state_dict 转换回原始形式,并将键和权重分布在相同数量的文件中。

成为“state-dict 不变”的一大优势是,您应该能够在无需任何代码更改或转换脚本的情况下,将来自 torchtune 的微调检查点与支持源格式的任何后训练工具(量化、评估、推理)一起使用。这是 torchtune 与周围生态系统交互的方式之一。

为了成为“state-dict 不变”,load_checkpointsave_checkpoint 方法利用了可用的权重转换器 此处


处理不同的检查点格式

torchtune 支持三种不同的 检查点,每个检查点都支持不同的检查点格式。

HFCheckpointer

此检查点以与 Hugging Face 的 transformers 框架兼容的格式读取和写入检查点。如上所述,这是 Hugging Face 模型中心中最流行的格式,也是每个 torchtune 配置的默认格式。

为了使此检查点正常工作,我们假设 checkpoint_dir 包含必要的检查点和 json 文件。确保一切正常工作的最简单方法是使用以下流程

  • 使用 tune download 从 HF 存储库下载模型。默认情况下,这将忽略“safetensors”文件。


    tune download meta-llama/Llama-2-7b-hf \
    --output-dir <checkpoint_dir> \
    --hf-token <hf-token>
    
  • 将此处指定的 output_dir 用作检查点的 checkpoint_dir 参数。


以下代码段说明了如何在 torchtune 配置文件中设置 HFCheckpointer。

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama2-7b-hf model we have
    # 2 .bin files. The checkpointer takes care of sorting
    # by id and so the order here does not matter
    checkpoint_files: [
        pytorch_model-00001-of-00002.bin,
        pytorch_model-00002-of-00002.bin,
    ]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state. More on this in the
    # next section
    recipe_checkpoint: null

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: False

注意

检查点转换为 HF 格式以及从 HF 格式转换检查点需要访问直接从 config.json 文件读取的模型参数。这有助于确保我们正确加载权重,或者在 HF 检查点文件与 torchtune 的模型实现之间存在差异时出错。该 json 文件与模型检查点一起从中心下载。有关它们在转换过程中如何使用的更多详细信息,请参见 此处


MetaCheckpointer

此检查点以与原始 meta-llama github 存储库兼容的格式读取和写入检查点。

为了使此检查点正常工作,我们假设 checkpoint_dir 包含必要的检查点和 json 文件。确保一切正常工作的最简单方法是使用以下流程

  • 使用 tune download 从 HF 存储库下载模型。默认情况下,这将忽略“safetensors”文件。


    tune download meta-llama/Llama-2-7b \
    --output-dir <checkpoint_dir> \
    --hf-token <hf-token>
    
  • 将上面的 output_dir 用作检查点的 checkpoint_dir


以下代码段说明了如何在 torchtune 配置文件中设置 MetaCheckpointer。

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelMetaCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama2-7b model we have
    # a single .pth file
    checkpoint_files: [consolidated.00.pth]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state. More on this in the
    # next section
    recipe_checkpoint: null

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: False

TorchTuneCheckpointer

此检查点以与 torchtune 的模型定义兼容的格式读取和写入检查点。它不会执行任何 state_dict 转换,目前用于测试或加载量化模型以进行生成。


中间 vs 最终检查点

torchtune 检查点支持两种检查点场景

训练结束检查点

完成的训练运行结束时的模型权重将写入文件。检查点确保输出检查点文件具有与用于开始训练的输入检查点文件相同的键。检查点还确保键在与原始检查点相同数量的文件中进行划分。输出 state dict 具有以下标准格式

{
    "key_1": weight_1,
    "key_2": weight_2,
    ...
}

训练中期检查点.

如果在训练过程中进行检查点,则输出检查点需要存储其他信息以确保后续的训练运行可以正确重启。除了模型检查点文件之外,我们还为中间检查点输出 recipe_state.pt 文件。这些目前在每个纪元结束时输出,并包含有关优化器状态、已完成的纪元数量等信息。

为了防止output_dir目录被检查点文件填满,每个 epoch 结束时都会覆盖配方状态。

输出状态字典具有以下格式

Model:
    {
        "key_1": weight_1,
        "key_2": weight_2,
        ...
    }

Recipe State:
    {
        "optimizer": ...,
        "epoch": ...,
        ...
    }

要从之前的检查点文件重新开始,您需要对配置文件进行以下更改

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    checkpoint_files: [
        hf_model_0001_0.pt,
        hf_model_0002_0.pt,
    ]

    # if we're restarting a previous run, we need to specify
    # the file with the checkpoint state
    recipe_checkpoint: recipe_state.pt

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True

LoRA 的检查点

在 torchtune 中,我们输出适配器权重和 LoRA 的完整模型“合并”权重。该“合并”检查点可以像使用任何训练后工具使用源检查点一样使用。有关更多详细信息,请查看我们的 LoRA 微调教程

这两种用例之间的主要区别在于,当您想要从检查点恢复训练时。在这种情况下,检查点程序需要访问初始冻结的基本模型权重以及学习到的适配器权重。此场景的配置如下所示

checkpointer:

    # checkpointer to use
    _component_: torchtune.utils.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the output_dir above
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. This is the ORIGINAL frozen checkpoint
    # and NOT the merged checkpoint output during training
    checkpoint_files: [
        pytorch_model-00001-of-00002.bin,
        pytorch_model-00002-of-00002.bin,
    ]

    # this refers to the adapter weights learnt during training
    adapter_checkpoint: adapter_0.pt

    # the file with the checkpoint state
    recipe_checkpoint: recipe_state.pt

    # dir for saving the output checkpoints. Usually set
    # to be the same as checkpoint_dir
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True

综合应用

现在我们将所有这些知识整合在一起!我们将加载一些检查点,创建一些模型并运行一个简单的正向传播。

在本节中,我们将使用 HF 格式的 Llama2 13B 模型。

import torch
from torchtune.utils import FullModelHFCheckpointer, ModelType
from torchtune.models.llama2 import llama2_13b

# Set the right directory and files
checkpoint_dir = 'Llama-2-13b-hf/'
pytorch_files = [
    'pytorch_model-00001-of-00003.bin',
    'pytorch_model-00002-of-00003.bin',
    'pytorch_model-00003-of-00003.bin'
]

# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir=checkpoint_dir,
    model_type=ModelType.LLAMA2
)
torchtune_sd = checkpointer.load_checkpoint()

# Setup the model and the input
model = llama2_13b()

# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
<All keys matched successfully>

# We have 32000 vocab tokens; lets generate an input with 70 tokens
x = torch.randint(0, 32000, (1, 70))

with torch.no_grad():
    model(x)

tensor([[[ -6.3989,  -9.0531,   3.2375,  ...,  -5.2822,  -4.4872,  -5.7469],
    [ -8.6737, -11.0023,   6.8235,  ...,  -2.6819,  -4.2424,  -4.0109],
    [ -4.6915,  -7.3618,   4.1628,  ...,  -2.8594,  -2.5857,  -3.1151],
    ...,
    [ -7.7808,  -8.2322,   2.8850,  ...,  -1.9604,  -4.7624,  -1.6040],
    [ -7.3159,  -8.5849,   1.8039,  ...,  -0.9322,  -5.2010,  -1.6824],
    [ -7.8929,  -8.8465,   3.3794,  ...,  -1.3500,  -4.6145,  -2.5931]]])

您可以对 torchtune 支持的任何模型执行此操作。您可以在此处找到模型和模型构建器的完整列表 here

我们希望这篇深入探讨能让你对 torchtune 中的检查点程序和相关实用程序有更深入的了解。祝您调参愉快!

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源