博客

在 torchtune 中将 Llama3.1 8B 精炼为 1B

作者: 2024年11月18日2025年5月5日暂无评论

在本篇博客中,我们介绍了一个案例研究:利用 torchtune 的知识蒸馏配方(recipe)将 Llama 3.1 8B 模型蒸馏为 Llama 3.2 1B。我们展示了知识蒸馏(KD)如何在训练后阶段用于提升模型执行指令任务的性能,并演示了用户如何使用该配方。

什么是知识蒸馏?

知识蒸馏是一种广泛使用的压缩技术,旨在将知识从较大的(教师)模型迁移到较小的(学生)模型中。较大的模型拥有更多的参数和知识容量,然而,这种巨大的容量也导致其部署成本更高。知识蒸馏可用于将大模型的知识压缩到小模型中。其核心思想是,通过学习大模型的输出,可以提升小模型的性能。

知识蒸馏是如何工作的?

通过在迁移集(transfer set)上进行训练,知识得以从教师模型迁移到学生模型,在训练过程中,学生模型被要求模仿教师模型的 token 级概率分布。其前提是教师模型的分布与迁移数据集相似。下图简化表示了知识蒸馏的工作原理。

Figure 1: Simplified representation of knowledge transfer from teacher to student model

图 1:从教师模型到学生模型的知识迁移简化示意图

由于大型语言模型(LLM)的知识蒸馏是一个活跃的研究领域,已有包括 MiniLLMDistiLLMAKLGeneralized KD 在内的多篇论文研究了不同的损失函数方法。在本案例研究中,我们以标准交叉熵(CE)损失与前向 Kullback-Leibler (KL) 散度损失作为基准。前向 KL 散度旨在通过强制使学生模型的分布与教师模型的所有分布对齐,来最小化差异。

为什么知识蒸馏很有用?

知识蒸馏的核心理念是:与从零开始训练或仅通过监督微调(SFT)相比,小模型可以利用教师模型的输出作为额外信号,从而获得更好的性能。例如,Llama 3.2 轻量级 1B 和 3B 文本模型结合了来自 Llama 3.1 8B 和 70B 的逻辑值(logits),以便在剪枝后恢复性能。此外,对于指令遵循任务的微调,LLM 蒸馏研究表明,知识蒸馏方法的效果可以优于单纯的监督微调(SFT)。

模型 方法 DollyEval Self-Inst S-NI
GPT-4 评估 GPT-4 评估 Rouge-L
Llama 7B SFT 73.0 69.2 32.4
KD 73.7 70.5 33.7
MiniLLM 76.4 73.1 35.5
Llama 1.1B SFT 22.1 27.8
KD 22.2 28.1
AKL 24.4 31.4
OpenLlama 3B SFT 47.3 41.7 29.3
KD 44.9 42.1 27.9
SeqKD 48.1 46.0 29.1
DistiLLM 59.9 53.3 37.6

表 1:知识蒸馏方法与监督微调的对比

以下是知识蒸馏与监督微调区别的简化示例。

监督微调 知识蒸馏
model = llama3_2_1b() ce_loss = CrossEntropyLoss() kd_loss = ForwardKLLoss() tokens, labels = batch["tokens"], batch["labels"] logits = model(tokens, ...) loss = ce_loss(logits, labels) loss.backward() model = llama3_2_1b() teacher_model = llama3_1_8b() ce_loss = CrossEntropyLoss() kd_loss = ForwardKLLoss() tokens, labels = batch["tokens"], batch["labels"] logits = model(tokens, ...) teacher_logits = teacher_model(tokens, ...) loss = ce_loss(logits, labels) + kd_loss(logits, teacher_logits, labels) loss.backward()

torchtune 中的 KD 配方

借助 torchtune,我们可以使用其 KD 配方轻松将知识蒸馏应用于 Llama3 以及其他 LLM 模型家族。该配方的目标是通过从 Llama3.1-8B 蒸馏,在 Alpaca 指令遵循数据集上微调 Llama3.2-1B。该配方侧重于训练后阶段,并假定教师和学生模型都已完成预训练。

首先,我们需要下载模型权重。为了与其他 torchtune 微调配置保持一致,我们将使用已完成指令调优的 Llama3.1-8B 作为教师模型,将 Llama3.2-1B 作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

为了使教师模型的分布与 Alpaca 数据集相似,我们将使用 LoRA 对教师模型进行微调。根据下一节中展示的实验,我们发现当教师模型已经在目标数据集上进行过微调时,KD 的效果更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令,在单个 GPU 上将微调后的 8B 模型蒸馏为 1B 模型。在本案例研究中,我们使用了一块 A100 80GB GPU。我们还提供了用于在多设备上运行的 分布式配方

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在本节中,我们演示了改变配置和超参数如何影响性能。默认情况下,我们的配置使用经过 LoRA 微调的 8B 教师模型、下载的 1B 学生模型、3e-4 的学习率以及 0.5 的 KD 损失比。在本案例研究中,我们在 alpaca_cleaned_dataset 上进行了微调,并通过 EleutherAI 的 LM 评估工具套件truthfulqa_mc2hellaswagcommonsense_qa 任务上评估了模型。让我们来看看以下因素的影响:

  1. 使用微调后的教师模型
  2. 使用微调后的学生模型
  3. KD 损失比和学习率的超参数调优

使用微调后的教师模型

配置中的默认设置使用了微调后的教师模型。现在,让我们看看不先对教师模型进行微调会有什么影响。

观察损失情况,使用基础版 8B 作为教师模型产生的损失比使用微调后的教师模型更高。KD 损失也保持相对恒定,这表明教师模型应当与迁移数据集具有相同的分布。

Figure 2: (left to right) KD loss from forward KL divergence, class loss from cross entropy, total loss: even combination of KD and class loss.

图 2:(从左至右)前向 KL 散度带来的 KD 损失、交叉熵带来的分类损失、总损失:KD 损失与分类损失的均衡组合。

在我们的基准测试中,可以看到 1B 模型的监督微调效果优于基础版 1B 模型。通过使用微调后的 8B 教师模型,我们在 truthfulqa 上看到了相当的结果,并在 hellaswag 和 commonsense 上取得了提升。当使用基础版 8B 作为教师时,虽然所有指标都有改善,但低于其他配置的效果。

模型 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基础版 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基础版 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基础版 8B 作为教师的 KD 0.444 0.4576 0.6123 0.5561
使用微调版 8B 作为教师的 KD 0.4481 0.4603 0.6157 0.5569

表 2:使用基础版与微调版 8B 作为教师模型的对比

使用微调后的学生模型

在这些实验中,我们观察了当学生模型已经完成微调时,KD 的影响。我们分析了基础版和微调版 8B 与 1B 模型不同组合的效果。

根据损失图,无论学生模型是否经过微调,使用微调后的教师模型都会导致较低的损失。同样值得注意的是,在使用微调后的学生模型时,分类损失开始增加。

Figure 3: Comparing losses of different teacher and student model initializations

图 3:比较不同教师和学生模型初始化方式的损失

使用微调后的学生模型进一步提高了 truthfulqa 的准确率,但在 hellaswag 和 commonsense 上准确率有所下降。使用微调后的教师模型和基础版学生模型在 hellaswag 和 commonsense 数据集上达到了最佳效果。基于这些发现,最佳配置将根据您所优化的评估数据集和指标而变化。

模型 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基础版 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基础版 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基础版 8B 和基础版 1B 的 KD 0.444 0.4576 0.6123 0.5561
使用基础版 8B 和微调版 1B 的 KD 0.4508 0.448 0.6004 0.5274
使用微调版 8B 和基础版 1B 的 KD 0.4481 0.4603 0.6157 0.5569
使用微调版 8B 和微调版 1B 的 KD 0.4713 0.4512 0.599 0.5233

表 3:使用基础版与微调版教师和学生模型的对比

超参数调优:学习率

默认情况下,该配方的学习率为 3e-4。在这些实验中,我们将学习率从 1e-3 调整到 1e-5。

根据损失图,除了 1e-5 学习率导致更高的 KD 和分类损失外,所有学习率产生的损失相似。

Figure 4: Comparing losses of different learning rates

图 4:比较不同学习率的损失

根据我们的基准测试,最佳学习率取决于您所优化的指标和任务。

模型 学习率 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基础版 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基础版 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用微调版 8B 和基础版 1B 的 KD 3e-4 0.4481 0.4603 0.6157 0.5569
使用微调版 8B 和基础版 1B 的 KD 1e-3 0.4453 0.4535 0.6071 0.5258
使用微调版 8B 和基础版 1B 的 KD 1e-4 0.4489 0.4606 0.6156 0.5586
使用微调版 8B 和基础版 1B 的 KD 1e-5 0.4547 0.4548 0.6114 0.5487

表 4:调整学习率的影响

超参数调优:KD 比率

默认情况下,KD 比率设置为 0.5,这给分类损失和 KD 损失赋予了相等的权重。在这些实验中,我们研究了不同 KD 比率的影响,其中 0 表示仅使用分类损失,1 表示仅使用 KD 损失。

总体而言,基准测试结果表明,对于这些任务和指标,较高的 KD 比率表现略好。

模型 kd_ratio (lr=3e-4) TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基础版 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基础版 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用微调版 8B 和基础版 1B 的 KD 0.25 0.4485 0.4595 0.6155 0.5602
使用微调版 8B 和基础版 1B 的 KD 0.5 0.4481 0.4603 0.6157 0.5569
使用微调版 8B 和基础版 1B 的 KD 0.75 0.4543 0.463 0.6189 0.5643
使用微调版 8B 和基础版 1B 的 KD 1.0 0.4537 0.4641 0.6177 0.5717

表 5:调整 KD 比率的影响

展望未来

在本篇博客中,我们介绍了如何通过 torchtune 在 Llama 3.1 8B 和 Llama 3.2 1B 的逻辑值上使用前向 KL 散度损失来蒸馏 LLM 的研究。未来还有许多方向值得探索,以进一步提高性能并为蒸馏方法提供更大的灵活性。

  • 扩展 KD 损失方案。KD 配方目前使用前向 KL 散度损失。然而,如上所述,使学生分布与整个教师分布对齐可能并不有效。已有包括 MiniLLMDistiLLMGeneralized KD 在内的多篇论文引入了新的 KD 损失函数和策略,以解决这一局限性,并被证明优于标准交叉熵配合前向 KL 散度损失的方法。例如,MiniLLM 使用反向 KL 散度来防止学生模型过高估计教师模型中概率较低的区域。DistiLLM 引入了倾斜 KL 损失(skewed KL loss)和自适应训练策略。
  • 启用跨分词器(cross-tokenizer)蒸馏。当前的配方要求教师和学生模型使用相同的分词器,这限制了在不同 LLM 家族之间进行蒸馏的能力。目前已有关于跨分词器方法的研究(例如 Universal Logit Distillation),这是我们可以进一步探索的方向。
  • 将蒸馏扩展到多模态 LLM 和编码器模型。KD 配方的一个自然延伸是将应用扩展到多模态 LLM。与部署更高效的 LLM 类似,我们也需要部署更小、更高效的多模态 LLM。此外,已有研究展示了将 LLM 作为编码器模型(例如 LLM2Vec)。从作为编码器的 LLM 蒸馏到较小的编码器模型也是一个值得探索的前景广阔的方向。