使用 torchtune 的端到端工作流程¶
在本教程中,我们将通过一个端到端示例,引导您了解如何使用 torchtune 微调、评估、可选地量化,然后使用您喜欢的 LLM 运行生成。我们还将介绍如何将社区中的一些流行工具和库与 torchtune 无缝集成使用。
torchtune 中除了微调之外可用的不同类型秘籍
连接所有这些秘籍的端到端示例
可与 torchtune 一起使用的不同工具和库
熟悉 torchtune 概览
确保 安装 torchtune
微调您的模型¶
首先,让我们使用 tune 命令行界面下载一个模型。以下命令将从 Hugging Face Hub 下载 Llama3.2 3B Instruct 模型,并将其保存到本地文件系统。Hugging Face 上传了原始权重 (consolidated.00.pth
) 和与 from_pretrained() API 兼容的权重 (*.safetensors
)。我们不需要两者,因此下载时将忽略原始权重。
$ tune download meta-llama/Llama-3.2-3B-Instruct --ignore-patterns "original/consolidated.00.pth"
Successfully downloaded model repo and wrote to the following locations:
/tmp/Llama-3.2-3B-Instruct/.cache
/tmp/Llama-3.2-3B-Instruct/.gitattributes
/tmp/Llama-3.2-3B-Instruct/LICENSE.txt
/tmp/Llama-3.2-3B-Instruct/README.md
/tmp/Llama-3.2-3B-Instruct/USE_POLICY.md
/tmp/Llama-3.2-3B-Instruct/config.json
/tmp/Llama-3.2-3B-Instruct/generation_config.json
/tmp/Llama-3.2-3B-Instruct/model-00001-of-00002.safetensors
...
注意
有关您可以使用 torchtune 开箱即用地微调的所有其他模型的列表,请查看我们的 模型页面。
在本教程中,我们将使用 LoRA 微调模型。LoRA 是一种参数高效的微调技术,当您没有太多 GPU 内存时尤其有用。LoRA 冻结基础 LLM 并添加极少比例的可学习参数。这有助于保持与梯度和优化器状态相关的内存较低。使用 torchtune,您应该能够在 RTX 3090/4090 上使用 bfloat16 在低于 16GB 的 GPU 内存中微调 Llama-3.2-3B-Instruct 模型(使用 LoRA)。有关如何使用 LoRA 的更多信息,请参阅我们的 LoRA 教程。
让我们使用 tune 命令行界面查找适用于此用例的正确配置。
$ tune ls
RECIPE CONFIG
full_finetune_single_device llama2/7B_full_low_memory
code_llama2/7B_full_low_memory
llama3/8B_full_single_device
llama3_1/8B_full_single_device
llama3_2/1B_full_single_device
llama3_2/3B_full_single_device
mistral/7B_full_low_memory
phi3/mini_full_low_memory
qwen2/7B_full_single_device
...
full_finetune_distributed llama2/7B_full
llama2/13B_full
llama3/8B_full
llama3_1/8B_full
llama3_2/1B_full
llama3_2/3B_full
mistral/7B_full
gemma2/9B_full
gemma2/27B_full
phi3/mini_full
qwen2/7B_full
...
lora_finetune_single_device llama2/7B_lora_single_device
llama2/7B_qlora_single_device
llama3/8B_lora_single_device
...
我们将使用我们的 单设备 LoRA 秘籍 进行微调,并使用 默认配置 中的标准设置。
这将使用 batch_size=4
和 dtype=bfloat16
对模型进行微调。在这些设置下,模型峰值内存使用量应约为 16GB,每个 epoch 的总训练时间约为 2-3 小时。
$ tune run lora_finetune_single_device --config llama3_2/3B_lora_single_device
Setting manual seed to local seed 3977464327. Local seed is seed + rank = 3977464327 + 0
Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to /tmp/torchtune/llama3_2_3B/lora_single_device/logs/log_1734708879.txt
Model is initialized with precision torch.bfloat16.
Memory stats after model init:
GPU peak memory allocation: 6.21 GiB
GPU peak memory reserved: 6.27 GiB
GPU peak memory active: 6.21 GiB
Tokenizer is initialized from file.
Optimizer and loss are initialized.
Loss is initialized.
Dataset and Sampler are initialized.
Learning rate scheduler is initialized.
Profiling disabled.
Profiler config after instantiation: {'enabled': False}
1|3|Loss: 1.943998098373413: 0%| | 3/1617 [00:21<3:04:47, 6.87s/it]
恭喜您训练好了模型!让我们看看 torchtune 生成的工件。一个简单的方法是运行 tree -a path/to/outputdir
,它应该会显示类似下方树形结构的内容。有 3 种类型的文件夹:
recipe_state: 包含
recipe_state.pt
,其中存储了重新启动上次中间 epoch 所需的信息。更多信息请查看我们的深入探讨 torchtune 中的检查点.;logs: 包含您训练运行的所有日志输出:损失、内存、异常等。
epoch_{}: 包含您训练好的模型权重和模型元数据。如果运行推理或推送到模型中心,您应直接使用此文件夹。
$ tree -a /tmp/torchtune/llama3_2_3B/lora_single_device
/tmp/torchtune/llama3_2_3B/lora_single_device
├── epoch_0
│ ├── adapter_config.json
│ ├── adapter_model.pt
│ ├── adapter_model.safetensors
│ ├── config.json
│ ├── model-00001-of-00002.safetensors
│ ├── model-00002-of-00002.safetensors
│ ├── generation_config.json
│ ├── LICENSE.txt
│ ├── model.safetensors.index.json
│ ├── original
│ │ ├── orig_params.json
│ │ ├── params.json
│ │ └── tokenizer.model
│ ├── original_repo_id.json
│ ├── README.md
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ ├── tokenizer.json
│ └── USE_POLICY.md
├── epoch_1
│ ├── adapter_config.json
│ ...
├── logs
│ └── log_1734652101.txt
└── recipe_state
└── recipe_state.pt
让我们了解这些文件
adapter_model.safetensors
和adapter_model.pt
是您 LoRA 训练的适配器权重。我们保存了一个重复的.pt
版本,以便于从检查点恢复。model-{}-of-{}.safetensors
是您训练好的完整模型权重(非适配器)。在 LoRA 微调时,只有当我们设置save_adapter_weights_only=False
时才会存在这些文件。在这种情况下,我们将基础模型与训练好的适配器合并,从而简化推理。adapter_config.json
由 Huggingface PEFT 在加载适配器时使用(稍后详述);model.safetensors.index.json
由 Hugging Face 的from_pretrained()
在加载模型权重时使用(稍后详述)所有其他文件最初都在
checkpoint_dir
中。它们在训练期间会自动复制。超过 100MiB 且以.safetensors
、.pth
、.pt
、.bin
结尾的文件会被忽略,从而使其轻量化。
评估您的模型¶
我们已经微调了一个模型。但这个模型的表现究竟如何?让我们通过结构化评估和实际使用来确定。
使用 EleutherAI 的 Eval Harness 运行评估¶
torchtune 集成了 EleutherAI 的评估工具包。您可以通过 eleuther_eval 秘籍找到一个示例。在本教程中,我们将通过修改其关联的配置 eleuther_evaluation.yaml 来直接使用此秘籍。
注意
对于本教程的这一部分,您应首先运行 pip install lm_eval>=0.4.5
来安装 EleutherAI 评估工具包。
由于我们计划更新所有检查点文件指向我们微调后的检查点,因此首先将配置文件复制到我们的本地工作目录,以便进行更改。
$ tune cp eleuther_evaluation ./custom_eval_config.yaml
Copied file to custom_eval_config.yaml
请注意,我们使用的是合并后的权重,而不是 LoRA 适配器。
# TODO: update to your desired epoch
output_dir: /tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: ${output_dir}/original/tokenizer.model
model:
# Notice that we don't pass the lora model. We are using the merged weights,
_component_: torchtune.models.llama3_2.llama3_2_3b
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: ${output_dir}
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors,
]
output_dir: ${output_dir}
model_type: LLAMA3_2
### OTHER PARAMETERS -- NOT RELATED TO THIS CHECKPOINT
# Environment
device: cuda
dtype: bf16
seed: 1234 # It is not recommended to change this seed, b/c it matches EleutherAI's default seed
# EleutherAI specific eval args
tasks: ["truthfulqa_mc2"]
limit: null
max_seq_length: 4096
batch_size: 8
enable_kv_cache: True
# Quantization specific args
quantizer: null
在本教程中,我们将使用该工具包中的 truthfulqa_mc2 任务。
此任务衡量模型在回答问题时的真实性倾向,并衡量模型在问题后接一个或多个真实回应和一个或多个虚假回应时的零样本准确率。
$ tune run eleuther_eval --config ./custom_eval_config.yaml
[evaluator.py:324] Running loglikelihood requests
...
生成一些输出¶
我们已经运行了一些评估,模型似乎表现不错。但它真的能为你在意的提示生成有意义的文本吗?让我们来找出答案!
为此,我们将使用 generate 秘籍 及其关联的 配置。
让我们首先将配置文件复制到我们的本地工作目录,以便进行更改。
$ tune cp generation ./custom_generation_config.yaml
Copied file to custom_generation_config.yaml
$ mkdir /tmp/torchtune/llama3_2_3B/lora_single_device/out
- 让我们修改
custom_generation_config.yaml
以包含以下更改。同样,您只需要 替换两个字段:
output_dir
和checkpoint_files
checkpoint_dir: /tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0
output_dir: /tmp/torchtune/llama3_2_3B/lora_single_device/out
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: ${checkpoint_dir}/original/tokenizer.model
prompt_template: null
model:
# Notice that we don't pass the lora model. We are using the merged weights,
_component_: torchtune.models.llama3_2.llama3_2_3b
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: ${checkpoint_dir}
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors,
]
output_dir: ${output_dir}
model_type: LLAMA3_2
### OTHER PARAMETERS -- NOT RELATED TO THIS CHECKPOINT
device: cuda
dtype: bf16
seed: 1234
# Generation arguments; defaults taken from gpt-fast
prompt:
system: null
user: "Tell me a joke. "
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
enable_kv_cache: True
quantizer: null
配置更新后,让我们开始生成!我们将使用默认的采样设置:top_k=300
和 temperature=0.8
。这些参数控制采样概率的计算方式。我们建议在使用这些参数进行试验之前,先用它们检查模型。
$ tune run generate --config ./custom_generation_config.yaml prompt.user="Tell me a joke. "
Tell me a joke. Here's a joke for you:
What do you call a fake noodle?
An impasta!
引入一些量化¶
我们依靠 torchao 进行 训练后量化。安装 torchao 后,要量化微调好的模型,我们可以运行以下命令:
# we also support `int8_weight_only()` and `int8_dynamic_activation_int8_weight()`, see
# https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
# for a full list of techniques that we support
from torchao.quantization.quant_api import quantize_, int4_weight_only
quantize_(model, int4_weight_only())
量化后,我们依靠 torch.compile
来加速。更多详情请参阅 此示例用法。
torchao 还提供了 此表格,列出了 llama2
和 llama3
的性能和准确性结果。
对于 Llama 模型,您可以在 torchao 中直接在量化模型上使用其 generate.py
脚本运行生成,如 本 readme 中所述。这样您就可以将自己的结果与之前链接表格中的结果进行比较。
在实际应用中使用您的模型¶
假设您对目前模型的表现很满意 - 您想用它做点什么!将其用于生产服务,发布到 Hugging Face Hub 等。由于我们处理检查点转换,您可以直接使用标准格式。
与 Hugging Face from_pretrained()
一起使用¶
情况 1:Hugging Face 使用基础模型 + 训练好的适配器
在这里,我们从 Hugging Face 模型中心加载基础模型。然后,我们使用 PeftModel 在其之上加载适配器。它将查找用于权重的 adapter_model.safetensors
文件和用于插入位置的 adapter_config.json
文件。
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
#TODO: update it to your chosen epoch
trained_model_path = "/tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0"
# Define the model and adapter paths
original_model_name = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(original_model_name)
# huggingface will look for adapter_model.safetensors and adapter_config.json
peft_model = PeftModel.from_pretrained(model, trained_model_path)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(original_model_name)
# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=max_length)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
prompt = "tell me a joke: '"
print("Base model output:", generate_text(peft_model, tokenizer, prompt))
情况 2:Hugging Face 使用合并后的权重
在这种情况下,Hugging Face 将检查 model.safetensors.index.json
文件,确定应该加载哪些文件。
from transformers import AutoModelForCausalLM, AutoTokenizer
#TODO: update it to your chosen epoch
trained_model_path = "/tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0"
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=trained_model_path,
)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(trained_model_path, safetensors=True)
# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=max_length)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
prompt = "Complete the sentence: 'Once upon a time...'"
print("Base model output:", generate_text(model, tokenizer, prompt))
与 vLLM 一起使用¶
vLLM 是一个用于 LLM 推理和服务的快速易用库。它包含许多很棒的功能,例如最先进的服务吞吐量、传入请求的连续批处理、量化和推测解码。
该库将加载任何 .safetensors
文件。由于我们已经合并了完整的模型权重和适配器权重,我们可以安全地删除(或移动)适配器权重,以免 vLLM 因这些文件而混淆。
rm /tmp/torchtune/llama3_2_3B/lora_single_device/base_model/adapter_model.safetensors
现在我们可以运行以下脚本
from vllm import LLM, SamplingParams
def print_outputs(outputs):
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print("-" * 80)
#TODO: update it to your chosen epoch
llm = LLM(
model="/tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0",
load_format="safetensors",
kv_cache_dtype="auto",
)
sampling_params = SamplingParams(max_tokens=16, temperature=0.5)
conversation = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
print_outputs(outputs)
将您的模型上传到 Hugging Face Hub¶
您的新模型运行良好,您想与全世界分享。最简单的方法是利用 huggingface_hub。
import huggingface_hub
api = huggingface_hub.HfApi()
#TODO: update it to your chosen epoch
trained_model_path = "/tmp/torchtune/llama3_2_3B/lora_single_device/epoch_0"
username = huggingface_hub.whoami()["name"]
repo_name = "my-model-trained-with-torchtune"
# if the repo doesn't exist
repo_id = huggingface_hub.create_repo(repo_name).repo_id
# if it already exists
repo_id = f"{username}/{repo_name}"
api.upload_folder(
folder_path=trained_model_path,
repo_id=repo_id,
repo_type="model",
create_pr=False
)
如果您愿意,也可以尝试使用命令行版本 huggingface-cli upload。
希望本教程为您提供了关于如何将 torchtune 用于您自己的工作流程的一些见解。祝您调优愉快!