在本篇博客中,我们介绍了一个案例研究:利用 torchtune 的知识蒸馏配方(recipe)将 Llama 3.1 8B 模型蒸馏为 Llama 3.2 1B。我们展示了知识蒸馏(KD)如何在训练后阶段用于提升模型执行指令任务的性能,并演示了用户如何使用该配方。
什么是知识蒸馏?
知识蒸馏是一种广泛使用的压缩技术,旨在将知识从较大的(教师)模型迁移到较小的(学生)模型中。较大的模型拥有更多的参数和知识容量,然而,这种巨大的容量也导致其部署成本更高。知识蒸馏可用于将大模型的知识压缩到小模型中。其核心思想是,通过学习大模型的输出,可以提升小模型的性能。
知识蒸馏是如何工作的?
通过在迁移集(transfer set)上进行训练,知识得以从教师模型迁移到学生模型,在训练过程中,学生模型被要求模仿教师模型的 token 级概率分布。其前提是教师模型的分布与迁移数据集相似。下图简化表示了知识蒸馏的工作原理。

图 1:从教师模型到学生模型的知识迁移简化示意图
由于大型语言模型(LLM)的知识蒸馏是一个活跃的研究领域,已有包括 MiniLLM、DistiLLM、AKL 和 Generalized KD 在内的多篇论文研究了不同的损失函数方法。在本案例研究中,我们以标准交叉熵(CE)损失与前向 Kullback-Leibler (KL) 散度损失作为基准。前向 KL 散度旨在通过强制使学生模型的分布与教师模型的所有分布对齐,来最小化差异。
为什么知识蒸馏很有用?
知识蒸馏的核心理念是:与从零开始训练或仅通过监督微调(SFT)相比,小模型可以利用教师模型的输出作为额外信号,从而获得更好的性能。例如,Llama 3.2 轻量级 1B 和 3B 文本模型结合了来自 Llama 3.1 8B 和 70B 的逻辑值(logits),以便在剪枝后恢复性能。此外,对于指令遵循任务的微调,LLM 蒸馏研究表明,知识蒸馏方法的效果可以优于单纯的监督微调(SFT)。
| 模型 | 方法 | DollyEval | Self-Inst | S-NI |
| GPT-4 评估 | GPT-4 评估 | Rouge-L | ||
| Llama 7B | SFT | 73.0 | 69.2 | 32.4 |
| KD | 73.7 | 70.5 | 33.7 | |
| MiniLLM | 76.4 | 73.1 | 35.5 | |
| Llama 1.1B | SFT | 22.1 | – | 27.8 |
| KD | 22.2 | – | 28.1 | |
| AKL | 24.4 | – | 31.4 | |
| OpenLlama 3B | SFT | 47.3 | 41.7 | 29.3 |
| KD | 44.9 | 42.1 | 27.9 | |
| SeqKD | 48.1 | 46.0 | 29.1 | |
| DistiLLM | 59.9 | 53.3 | 37.6 |
表 1:知识蒸馏方法与监督微调的对比
以下是知识蒸馏与监督微调区别的简化示例。
| 监督微调 | 知识蒸馏 |
|---|---|
model = llama3_2_1b() ce_loss = CrossEntropyLoss() kd_loss = ForwardKLLoss() tokens, labels = batch["tokens"], batch["labels"] logits = model(tokens, ...) loss = ce_loss(logits, labels) loss.backward() |
model = llama3_2_1b() teacher_model = llama3_1_8b() ce_loss = CrossEntropyLoss() kd_loss = ForwardKLLoss() tokens, labels = batch["tokens"], batch["labels"] logits = model(tokens, ...) teacher_logits = teacher_model(tokens, ...) loss = ce_loss(logits, labels) + kd_loss(logits, teacher_logits, labels) loss.backward() |
torchtune 中的 KD 配方
借助 torchtune,我们可以使用其 KD 配方轻松将知识蒸馏应用于 Llama3 以及其他 LLM 模型家族。该配方的目标是通过从 Llama3.1-8B 蒸馏,在 Alpaca 指令遵循数据集上微调 Llama3.2-1B。该配方侧重于训练后阶段,并假定教师和学生模型都已完成预训练。
首先,我们需要下载模型权重。为了与其他 torchtune 微调配置保持一致,我们将使用已完成指令调优的 Llama3.1-8B 作为教师模型,将 Llama3.2-1B 作为学生模型。
tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>
tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>
为了使教师模型的分布与 Alpaca 数据集相似,我们将使用 LoRA 对教师模型进行微调。根据下一节中展示的实验,我们发现当教师模型已经在目标数据集上进行过微调时,KD 的效果更好。
tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
最后,我们可以运行以下命令,在单个 GPU 上将微调后的 8B 模型蒸馏为 1B 模型。在本案例研究中,我们使用了一块 A100 80GB GPU。我们还提供了用于在多设备上运行的 分布式配方。
tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device
消融研究
在本节中,我们演示了改变配置和超参数如何影响性能。默认情况下,我们的配置使用经过 LoRA 微调的 8B 教师模型、下载的 1B 学生模型、3e-4 的学习率以及 0.5 的 KD 损失比。在本案例研究中,我们在 alpaca_cleaned_dataset 上进行了微调,并通过 EleutherAI 的 LM 评估工具套件在 truthfulqa_mc2、hellaswag 和 commonsense_qa 任务上评估了模型。让我们来看看以下因素的影响:
- 使用微调后的教师模型
- 使用微调后的学生模型
- KD 损失比和学习率的超参数调优
使用微调后的教师模型
配置中的默认设置使用了微调后的教师模型。现在,让我们看看不先对教师模型进行微调会有什么影响。
观察损失情况,使用基础版 8B 作为教师模型产生的损失比使用微调后的教师模型更高。KD 损失也保持相对恒定,这表明教师模型应当与迁移数据集具有相同的分布。

图 2:(从左至右)前向 KL 散度带来的 KD 损失、交叉熵带来的分类损失、总损失:KD 损失与分类损失的均衡组合。
在我们的基准测试中,可以看到 1B 模型的监督微调效果优于基础版 1B 模型。通过使用微调后的 8B 教师模型,我们在 truthfulqa 上看到了相当的结果,并在 hellaswag 和 commonsense 上取得了提升。当使用基础版 8B 作为教师时,虽然所有指标都有改善,但低于其他配置的效果。
| 模型 | TruthfulQA | hellaswag | commonsense | |
| mc2 | acc | acc_norm | acc | |
| 基础版 Llama 3.1 8B | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| 使用 LoRA 微调的 Llama 3.1 8B | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| 基础版 Llama 3.2 1B | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| 使用 LoRA 微调的 Llama 3.2 1B | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| 使用基础版 8B 作为教师的 KD | 0.444 | 0.4576 | 0.6123 | 0.5561 |
| 使用微调版 8B 作为教师的 KD | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
表 2:使用基础版与微调版 8B 作为教师模型的对比
使用微调后的学生模型
在这些实验中,我们观察了当学生模型已经完成微调时,KD 的影响。我们分析了基础版和微调版 8B 与 1B 模型不同组合的效果。
根据损失图,无论学生模型是否经过微调,使用微调后的教师模型都会导致较低的损失。同样值得注意的是,在使用微调后的学生模型时,分类损失开始增加。

图 3:比较不同教师和学生模型初始化方式的损失
使用微调后的学生模型进一步提高了 truthfulqa 的准确率,但在 hellaswag 和 commonsense 上准确率有所下降。使用微调后的教师模型和基础版学生模型在 hellaswag 和 commonsense 数据集上达到了最佳效果。基于这些发现,最佳配置将根据您所优化的评估数据集和指标而变化。
| 模型 | TruthfulQA | hellaswag | commonsense | |
| mc2 | acc | acc_norm | acc | |
| 基础版 Llama 3.1 8B | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| 使用 LoRA 微调的 Llama 3.1 8B | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| 基础版 Llama 3.2 1B | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| 使用 LoRA 微调的 Llama 3.2 1B | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| 使用基础版 8B 和基础版 1B 的 KD | 0.444 | 0.4576 | 0.6123 | 0.5561 |
| 使用基础版 8B 和微调版 1B 的 KD | 0.4508 | 0.448 | 0.6004 | 0.5274 |
| 使用微调版 8B 和基础版 1B 的 KD | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
| 使用微调版 8B 和微调版 1B 的 KD | 0.4713 | 0.4512 | 0.599 | 0.5233 |
表 3:使用基础版与微调版教师和学生模型的对比
超参数调优:学习率
默认情况下,该配方的学习率为 3e-4。在这些实验中,我们将学习率从 1e-3 调整到 1e-5。
根据损失图,除了 1e-5 学习率导致更高的 KD 和分类损失外,所有学习率产生的损失相似。

图 4:比较不同学习率的损失
根据我们的基准测试,最佳学习率取决于您所优化的指标和任务。
| 模型 | 学习率 | TruthfulQA | hellaswag | commonsense | |
| mc2 | acc | acc_norm | acc | ||
| 基础版 Llama 3.1 8B | – | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| 使用 LoRA 微调的 Llama 3.1 8B | – | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| 基础版 Llama 3.2 1B | – | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| 使用 LoRA 微调的 Llama 3.2 1B | – | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| 使用微调版 8B 和基础版 1B 的 KD | 3e-4 | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
| 使用微调版 8B 和基础版 1B 的 KD | 1e-3 | 0.4453 | 0.4535 | 0.6071 | 0.5258 |
| 使用微调版 8B 和基础版 1B 的 KD | 1e-4 | 0.4489 | 0.4606 | 0.6156 | 0.5586 |
| 使用微调版 8B 和基础版 1B 的 KD | 1e-5 | 0.4547 | 0.4548 | 0.6114 | 0.5487 |
表 4:调整学习率的影响
超参数调优:KD 比率
默认情况下,KD 比率设置为 0.5,这给分类损失和 KD 损失赋予了相等的权重。在这些实验中,我们研究了不同 KD 比率的影响,其中 0 表示仅使用分类损失,1 表示仅使用 KD 损失。
总体而言,基准测试结果表明,对于这些任务和指标,较高的 KD 比率表现略好。
| 模型 | kd_ratio (lr=3e-4) | TruthfulQA | hellaswag | commonsense | |
| mc2 | acc | acc_norm | acc | ||
| 基础版 Llama 3.1 8B | – | 0.5401 | 0.5911 | 0.7915 | 0.7707 |
| 使用 LoRA 微调的 Llama 3.1 8B | – | 0.5475 | 0.6031 | 0.7951 | 0.7789 |
| 基础版 Llama 3.2 1B | – | 0.4384 | 0.4517 | 0.6064 | 0.5536 |
| 使用 LoRA 微调的 Llama 3.2 1B | – | 0.4492 | 0.4595 | 0.6132 | 0.5528 |
| 使用微调版 8B 和基础版 1B 的 KD | 0.25 | 0.4485 | 0.4595 | 0.6155 | 0.5602 |
| 使用微调版 8B 和基础版 1B 的 KD | 0.5 | 0.4481 | 0.4603 | 0.6157 | 0.5569 |
| 使用微调版 8B 和基础版 1B 的 KD | 0.75 | 0.4543 | 0.463 | 0.6189 | 0.5643 |
| 使用微调版 8B 和基础版 1B 的 KD | 1.0 | 0.4537 | 0.4641 | 0.6177 | 0.5717 |
表 5:调整 KD 比率的影响
展望未来
在本篇博客中,我们介绍了如何通过 torchtune 在 Llama 3.1 8B 和 Llama 3.2 1B 的逻辑值上使用前向 KL 散度损失来蒸馏 LLM 的研究。未来还有许多方向值得探索,以进一步提高性能并为蒸馏方法提供更大的灵活性。
- 扩展 KD 损失方案。KD 配方目前使用前向 KL 散度损失。然而,如上所述,使学生分布与整个教师分布对齐可能并不有效。已有包括 MiniLLM、DistiLLM 和 Generalized KD 在内的多篇论文引入了新的 KD 损失函数和策略,以解决这一局限性,并被证明优于标准交叉熵配合前向 KL 散度损失的方法。例如,MiniLLM 使用反向 KL 散度来防止学生模型过高估计教师模型中概率较低的区域。DistiLLM 引入了倾斜 KL 损失(skewed KL loss)和自适应训练策略。
- 启用跨分词器(cross-tokenizer)蒸馏。当前的配方要求教师和学生模型使用相同的分词器,这限制了在不同 LLM 家族之间进行蒸馏的能力。目前已有关于跨分词器方法的研究(例如 Universal Logit Distillation),这是我们可以进一步探索的方向。
- 将蒸馏扩展到多模态 LLM 和编码器模型。KD 配方的一个自然延伸是将应用扩展到多模态 LLM。与部署更高效的 LLM 类似,我们也需要部署更小、更高效的多模态 LLM。此外,已有研究展示了将 LLM 作为编码器模型(例如 LLM2Vec)。从作为编码器的 LLM 蒸馏到较小的编码器模型也是一个值得探索的前景广阔的方向。