torchtune 中的检查点¶
本深入探讨将带您了解检查点和相关实用程序的设计和行为。
torchtune 的检查点设计
检查点格式以及我们如何处理它们
检查点场景:中间与最终以及 LoRA 与完整微调
概述¶
torchtune 检查点旨在成为可组合的组件,可以插入到任何食谱中 - 训练、评估或生成。每个检查点都支持一组模型和场景,使其易于理解、调试和扩展。
在我们深入研究 torchtune 中的检查点之前,让我们定义一些概念。
检查点格式¶
在本深入探讨中,我们将讨论不同的检查点格式以及 torchtune 如何处理它们。让我们仔细看看这些不同的格式。
简单地说,检查点的格式由 state_dict 决定,以及如何在磁盘上的文件中存储此 state_dict。每个权重都与一个字符串键相关联,该字符串键在 state_dict 中标识它。如果存储的检查点中键的字符串标识符与模型定义中的字符串标识符不完全匹配,您要么遇到明确的错误(加载 state_dict 将引发异常),要么遇到更糟的情况 - 静默错误(加载将成功,但训练或推理将无法按预期工作)。除了键对齐之外,您还需要权重(state_dict 中的值)的形状与模型定义预期的形状完全匹配。
让我们看看 Llama2 的两种流行格式。
Meta 格式
这是官方 Llama2 实现支持的格式。当您从 meta-llama 网站 下载 Llama2 7B 模型时,您将可以访问单个 .pth
检查点文件。您可以使用 torch.load
轻松检查此检查点的内容
>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>> print(f'{key}: {value.shape}')
tok_embeddings.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
292
state_dict 包含 292 个键,包括一个名为 tok_embeddings
的输入嵌入表。此 state_dict 的模型定义期望一个具有 32000
个令牌的嵌入层,每个令牌都有一个维度为 4096
的嵌入。
HF 格式
这是 Hugging Face 模型中心中最流行的格式,也是每个 torchtune 配置中的默认格式。这也是您从 Llama-2-7b-hf 仓库下载 llama2 模型时获得的格式。
第一个主要区别是 state_dict 分散在两个 .bin
文件中。要正确加载检查点,您需要将这些文件拼凑起来。让我们检查其中一个文件。
>>> import torch
>>> state_dict = torch.load('pytorch_model-00001-of-00002.bin', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
>>> print(f'{key}: {value.shape}')
model.embed_tokens.weight: torch.Size([32000, 4096])
...
...
>>> print(len(state_dict.keys()))
241
state_dict 不仅包含更少的键(预期如此,因为这是两个文件之一),而且嵌入表被称为 model.embed_tokens
而不是 tok_embeddings
。名称的这种不匹配会导致您尝试加载 state_dict 时出现异常。这两个层的大小是相同的,这是预期的。
如您所见,如果您不小心,您可能会在检查点加载和保存过程中遇到许多错误。torchtune 检查点通过为您管理 state_dict 使得此过程不易出错。torchtune 旨在“对 state_dict 不变”。
加载时,torchtune 接受来自多个来源的多格式检查点。您不必担心在每次运行食谱时都显式转换检查点。
保存时,torchtune 会以与源相同的格式生成检查点。这包括将 state_dict 转换回原始形式并将键和权重分散在相同数量的文件中。
成为“对 state_dict 不变”的一个主要优势是,您应该能够将 torchtune 中的微调检查点与任何支持源格式的后训练工具(量化、评估、推理)一起使用,而无需任何代码更改或转换脚本。这是 torchtune 与周围生态系统交互的方式之一。
注意
要以这种方式实现对 state_dict“不变”,每个检查点的 load_checkpoint
和 save_checkpoint
方法使用权重转换器,这些转换器可以正确地将权重映射到检查点格式之间。例如,当从 Hugging Face 加载权重时,我们会在加载和保存时对某些权重应用排列,以确保检查点行为完全相同。为了进一步说明这一点,Llama 系列模型使用 通用的权重转换器函数,而某些其他模型(如 Phi3)则有自己的 转换函数,可以在其模型文件夹中找到。
处理不同的检查点格式¶
torchtune 支持三种不同的 检查点,每种检查点都支持不同的检查点格式。
HFCheckpointer
¶
此检查点以与 Hugging Face 的 transformers 框架兼容的格式读取和写入检查点。如上所述,这是 Hugging Face 模型中心中最流行的格式,也是每个 torchtune 配置中的默认格式。
为了使此检查点正常工作,我们假设 checkpoint_dir
包含必要的检查点和 json 文件。确保一切正常工作的最简单方法是使用以下流程
使用 tune download 从 HF 仓库下载模型。默认情况下,这将忽略“safetensors”文件。
tune download meta-llama/Llama-2-7b-hf \ --output-dir <checkpoint_dir> \ --hf-token <hf-token>
将此处指定的
output_dir
用作检查点的checkpoint_dir
参数。
以下代码片段说明了如何在 torchtune 配置文件中设置 HFCheckpointer。
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>
# checkpoint files. For the llama2-7b-hf model we have
# 2 .bin files. The checkpointer takes care of sorting
# by id and so the order here does not matter
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin,
]
# if we're restarting a previous run, we need to specify
# the file with the checkpoint state. More on this in the
# next section
recipe_checkpoint: null
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: False
注意
将检查点转换为 Hugging Face 格式或从 Hugging Face 格式转换需要访问模型参数,这些参数直接从 config.json
文件中读取。这有助于确保我们以正确的方式加载权重,或者在 Hugging Face 检查点文件与 torchtune 的模型实现不一致的情况下,引发错误。此 json 文件与模型检查点一起从 hub 下载。
MetaCheckpointer
¶
此检查点读取器和写入器以与原始 meta-llama github 存储库兼容的格式进行检查点操作。
为了使此检查点正常工作,我们假设 checkpoint_dir
包含必要的检查点和 json 文件。确保一切正常工作的最简单方法是使用以下流程
使用 tune download 从 HF 仓库下载模型。默认情况下,这将忽略“safetensors”文件。
tune download meta-llama/Llama-2-7b \ --output-dir <checkpoint_dir> \ --hf-token <hf-token>
将上面的
output_dir
用作检查点读取器的checkpoint_dir
。
以下代码片段说明了如何在 torchtune 配置文件中设置 MetaCheckpointer。
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelMetaCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>
# checkpoint files. For the llama2-7b model we have
# a single .pth file
checkpoint_files: [consolidated.00.pth]
# if we're restarting a previous run, we need to specify
# the file with the checkpoint state. More on this in the
# next section
recipe_checkpoint: null
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: False
TorchTuneCheckpointer
¶
此检查点读取器和写入器以与 torchtune 的模型定义兼容的格式进行检查点操作。它不执行任何 state_dict 转换,目前用于测试或加载量化模型以进行生成。
中间检查点与最终检查点¶
torchtune 检查点读取器支持两种检查点方案
训练结束时的检查点
已完成训练运行结束时的模型权重将写入文件。检查点读取器确保输出检查点文件与用于开始训练的输入检查点文件具有相同的键。检查点读取器还确保这些键跨越与原始检查点相同数量的文件进行分区。输出 state dict 具有以下标准格式
{ "key_1": weight_1, "key_2": weight_2, ... }
训练过程中的检查点.
如果在训练过程中进行检查点操作,则输出检查点需要存储额外的信息,以确保后续的训练运行能够正确地重新启动。除了模型检查点文件之外,我们还会为中间检查点输出一个 recipe_state.pt
文件。这些文件目前在每个 epoch 结束时输出,并包含诸如优化器状态、已完成的 epoch 数量等信息。
为了防止我们用检查点文件淹没 output_dir
,配方状态在每个 epoch 结束时会被覆盖。
输出 state dict 具有以下格式
Model: { "key_1": weight_1, "key_2": weight_2, ... } Recipe State: { "optimizer": ..., "epoch": ..., ... }
要从之前的检查点文件重新开始,您需要对配置文件进行以下更改
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: <checkpoint_dir>
# checkpoint files. Note that you will need to update this
# section of the config with the intermediate checkpoint files
checkpoint_files: [
hf_model_0001_0.pt,
hf_model_0002_0.pt,
]
# if we're restarting a previous run, we need to specify
# the file with the checkpoint state
recipe_checkpoint: recipe_state.pt
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: True
LoRA 的检查点操作¶
在 torchtune 中,我们为 LoRA 输出适配器权重和完整的“合并”模型权重。“合并”检查点可以像使用任何训练后工具使用源检查点一样使用。有关更多详细信息,请查看我们的 LoRA 微调教程。此外,通过在保存检查点时将选项“save_adapter_weights_only”设置为 True,您可以选择仅保存适配器权重。
这两种用例之间的主要区别在于,当您想要从检查点恢复训练时。在这种情况下,检查点读取器需要访问初始冻结的基模型权重以及学习到的适配器权重。此场景的配置如下所示
checkpointer:
# checkpointer to use
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>
# checkpoint files. This is the ORIGINAL frozen checkpoint
# and NOT the merged checkpoint output during training
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin,
]
# this refers to the adapter weights learnt during training
adapter_checkpoint: adapter_0.pt
# the file with the checkpoint state
recipe_checkpoint: recipe_state.pt
# dir for saving the output checkpoints. Usually set
# to be the same as checkpoint_dir
output_dir: <checkpoint_dir>
# model_type which specifies how to convert the state_dict
# into a format which torchtune understands
model_type: LLAMA2
# set to True if restarting training
resume_from_checkpoint: True
# Set to True to save only the adapter weights
save_adapter_weights_only: False
将所有内容整合在一起¶
现在让我们将所有这些知识整合在一起!我们将加载一些检查点,创建一些模型并运行一个简单的正向传播。
在本节中,我们将使用 Hugging Face 格式的 Llama2 13B 模型。
import torch
from torchtune.training import FullModelHFCheckpointer, ModelType
from torchtune.models.llama2 import llama2_13b
# Set the right directory and files
checkpoint_dir = 'Llama-2-13b-hf/'
pytorch_files = [
'pytorch_model-00001-of-00003.bin',
'pytorch_model-00002-of-00003.bin',
'pytorch_model-00003-of-00003.bin'
]
# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=checkpoint_dir,
checkpoint_files=pytorch_files,
output_dir=checkpoint_dir,
model_type=ModelType.LLAMA2
)
torchtune_sd = checkpointer.load_checkpoint()
# Setup the model and the input
model = llama2_13b()
# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
<All keys matched successfully>
# We have 32000 vocab tokens; lets generate an input with 70 tokens
x = torch.randint(0, 32000, (1, 70))
with torch.no_grad():
model(x)
tensor([[[ -6.3989, -9.0531, 3.2375, ..., -5.2822, -4.4872, -5.7469],
[ -8.6737, -11.0023, 6.8235, ..., -2.6819, -4.2424, -4.0109],
[ -4.6915, -7.3618, 4.1628, ..., -2.8594, -2.5857, -3.1151],
...,
[ -7.7808, -8.2322, 2.8850, ..., -1.9604, -4.7624, -1.6040],
[ -7.3159, -8.5849, 1.8039, ..., -0.9322, -5.2010, -1.6824],
[ -7.8929, -8.8465, 3.3794, ..., -1.3500, -4.6145, -2.5931]]])
您可以使用 torchtune 支持的任何模型来执行此操作。您可以在 此处 找到模型和模型构建器的完整列表。
我们希望这次深入研究能够让您更深入地了解 torchtune 中的检查点读取器和相关实用程序。祝您调优愉快!