跳转到主要内容
博客

在 torchtune 中将 Llama3.1 8B 精炼为 1B

作者: 2024年11月18日2025年5月5日无评论

在本博客中,我们将展示一个案例研究,介绍如何使用torchtune的知识蒸馏秘籍将Llama 3.1 8B模型蒸馏为Llama 3.2 1B。我们将演示知识蒸馏(KD)如何在训练后用于提高指令遵循任务的性能,并展示用户如何利用此秘籍。

什么是知识蒸馏?

知识蒸馏是一种广泛使用的压缩技术,它将知识从一个较大的(教师)模型转移到一个较小的(学生)模型。较大的模型具有更多的参数和知识容量,但这种较大的容量在部署时也更具计算成本。知识蒸馏可用于将较大模型的知识压缩到较小的模型中。其思想是,通过从较大模型的输出中学习,可以提高较小模型的性能。

知识蒸馏如何工作?

知识通过在迁移集上进行训练,从教师模型转移到学生模型,学生模型被训练模仿教师模型的标记级概率分布。假设教师模型分布与迁移数据集相似。下图是KD工作原理的简化表示。

Figure 1: Simplified representation of knowledge transfer from teacher to student model

图1:知识从教师模型向学生模型转移的简化表示

由于LLM的知识蒸馏是一个活跃的研究领域,因此有许多论文,例如MiniLLMDistiLLMAKLGeneralized KD,研究了不同的损失方法。在本案例研究中,我们重点关注以标准交叉熵(CE)损失和前向Kullback-Leibler(KL)散度损失作为基线。前向KL散度旨在通过强制学生的分布与教师的所有分布对齐来最小化差异。

为什么知识蒸馏有用?

知识蒸馏的理念是,与从头开始训练或仅进行监督微调相比,小型模型可以使用教师模型的输出作为额外信号,从而获得更好的性能。例如,Llama 3.2 轻量级 1B 和 3B 文本模型融入了 Llama 3.1 8B 和 70B 的 logits,以在剪枝后恢复性能。此外,对于指令遵循任务的微调,LLM 蒸馏研究表明,知识蒸馏方法可以优于单独的监督微调 (SFT)。

模型 方法 DollyEval 自指导 S-NI
GPT-4评估 GPT-4评估 Rouge-L
Llama 7B SFT 73.0 69.2 32.4
KD 73.7 70.5 33.7
MiniLLM 76.4 73.1 35.5
Llama 1.1B SFT 22.1 27.8
KD 22.2 28.1
AKL 24.4 31.4
OpenLlama 3B SFT 47.3 41.7 29.3
KD 44.9 42.1 27.9
SeqKD 48.1 46.0 29.1
DistiLLM 59.9 53.3 37.6

表1:知识蒸馏方法与监督微调的比较

以下是知识蒸馏与监督微调有何不同的简化示例。

监督微调 知识蒸馏
模型 = llama3_2_1b() ce_loss = CrossEntropyLoss() kd_loss = ForwardKLLoss() tokens, labels = batch["tokens"], batch["labels"] logits = model(tokens, ...) loss = ce_loss(logits, labels) loss.backward() 模型 = llama3_2_1b() teacher_model = llama3_1_8b() ce_loss = CrossEntropyLoss() kd_loss = ForwardKLLoss() tokens, labels = batch["tokens"], batch["labels"] logits = model(tokens, ...) teacher_logits = teacher_model(tokens, ...) loss = ce_loss(logits, labels) + kd_loss(logits, teacher_logits, labels) loss.backward()

torchtune中的KD秘籍

借助torchtune,我们可以使用torchtune的KD秘籍,轻松将知识蒸馏应用于Llama3以及其他LLM模型系列。此秘籍的目标是通过从Llama3.1-8B中蒸馏,在Alpaca指令遵循数据集上微调Llama3.2-1B。此秘籍侧重于训练后,并假设教师和学生模型已预训练。

首先,我们需要下载模型权重。为了与其他torchtune微调配置保持一致,我们将使用Llama3.1-8B的指令微调模型作为教师模型,Llama3.2-1B作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

为了使教师模型分布与Alpaca数据集相似,我们将使用LoRA微调教师模型。根据我们在下一节中展示的实验,我们发现当教师模型已在目标数据集上微调时,KD的性能更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令,在单个 GPU 上将微调后的 8B 模型蒸馏为 1B 模型。在本案例研究中,我们使用了单个 A100 80GB GPU。我们还有一份分布式秘籍,用于在多个设备上运行。

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在本节中,我们将演示更改配置和超参数如何影响性能。默认情况下,我们的配置使用LoRA微调的8B教师模型、下载的1B学生模型、3e-4的学习率和0.5的KD损失比率。在本案例研究中,我们对alpaca_cleaned_dataset进行了微调,并通过EleutherAI的LM评估框架,在truthfulqa_mc2hellaswagcommonsense_qa任务上评估了模型。我们来看一下以下因素的影响:

  1. 使用微调后的教师模型
  2. 使用微调后的学生模型
  3. KD损失比和学习率的超参数调整

使用微调后的教师模型

配置中的默认设置使用微调后的教师模型。现在,我们来看看不首先微调教师模型的影响。

从损失来看,使用基线8B作为教师模型会导致比使用微调后的教师模型更高的损失。KD损失也保持相对不变,这表明教师模型应该与迁移数据集具有相同的分布。

Figure 2: (left to right) KD loss from forward KL divergence, class loss from cross entropy, total loss: even combination of KD and class loss.

图2:(从左到右)前向KL散度的KD损失、交叉熵的类别损失、总损失:KD和类别损失的均匀组合。

在我们的基准测试中,我们可以看到1B模型的监督微调比基线1B模型获得了更好的准确性。通过使用微调后的8B教师模型,我们在truthfulqa上看到了可比较的结果,并在hellaswag和commonsense上有所改进。当使用基线8B作为教师模型时,我们在所有指标上都看到了改进,但低于其他配置。

模型 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用LoRA微调的Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基线Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用LoRA微调的Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基线8B作为教师模型的KD 0.444 0.4576 0.6123 0.5561
使用微调8B作为教师模型的KD 0.4481 0.4603 0.6157 0.5569

表2:使用基线和微调8B作为教师模型的比较

使用微调后的学生模型

对于这些实验,我们考察了学生模型已微调时KD的效果。我们分析了使用基线和微调的8B和1B模型不同组合的效果。

根据损失图,使用微调的教师模型会产生较低的损失,无论学生模型是否经过微调。同样值得注意的是,当使用微调的学生模型时,类别损失开始增加。

Figure 3: Comparing losses of different teacher and student model initializations

图3:比较不同教师和学生模型初始化下的损失

使用微调后的学生模型进一步提高了truthfulqa的准确性,但hellaswag和commonsense的准确性有所下降。使用微调后的教师模型和基线学生模型在hellaswag和commonsense数据集上取得了最佳结果。根据这些发现,最佳配置将根据您优化的评估数据集和指标而变化。

模型 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用LoRA微调的Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基线Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用LoRA微调的Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基线8B和基线1B的KD 0.444 0.4576 0.6123 0.5561
使用基线8B和微调1B的KD 0.4508 0.448 0.6004 0.5274
使用微调8B和基线1B的KD 0.4481 0.4603 0.6157 0.5569
使用微调8B和微调1B的KD 0.4713 0.4512 0.599 0.5233

表3:使用基线和微调的教师和学生模型进行比较

超参数调优:学习率

默认情况下,该秘籍的学习率为3e-4。对于这些实验,我们将学习率从高至1e-3调整至低至1e-5。

根据损失图,除了1e-5的学习率导致KD和类别损失更高外,所有学习率都导致了相似的损失。

Figure 4: Comparing losses of different learning rates

图4:比较不同学习率下的损失

根据我们的基准测试,最佳学习率会根据您正在优化的指标和任务而变化。

模型 学习率 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用LoRA微调的Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基线Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用LoRA微调的Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用微调8B和基线1B的KD 3e-4 0.4481 0.4603 0.6157 0.5569
使用微调8B和基线1B的KD 1e-3 0.4453 0.4535 0.6071 0.5258
使用微调8B和基线1B的KD 1e-4 0.4489 0.4606 0.6156 0.5586
使用微调8B和基线1B的KD 1e-5 0.4547 0.4548 0.6114 0.5487

表4:调整学习率的效果

超参数调优:KD比率

默认情况下,KD 比率设置为 0.5,这使得类别损失和 KD 损失的权重相同。在这些实验中,我们研究了不同 KD 比率的影响,其中 0 仅使用类别损失,1 仅使用 KD 损失。

总的来说,基准测试结果表明,对于这些任务和指标,更高的KD比率性能略好。

模型 kd_ratio (lr=3e-4) TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用LoRA微调的Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基线Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用LoRA微调的Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用微调8B和基线1B的KD 0.25 0.4485 0.4595 0.6155 0.5602
使用微调8B和基线1B的KD 0.5 0.4481 0.4603 0.6157 0.5569
使用微调8B和基线1B的KD 0.75 0.4543 0.463 0.6189 0.5643
使用微调8B和基线1B的KD 1.0 0.4537 0.4641 0.6177 0.5717

表5:调整KD比率的效果

展望未来

在本博客中,我们研究了如何通过 torchtune 使用前向 KL 散度损失对 Llama 3.1 8B 和 Llama 3.2 1B logit 进行 LLM 蒸馏。未来有许多方向可以进一步探索,以提高性能并提供蒸馏方法的更多灵活性。

  • 扩展KD损失方案。KD秘籍使用了前向KL散度损失。然而,如上所述,将学生分布与整个教师分布对齐可能效果不佳。有几篇论文,例如MiniLLMDistiLLMGeneralized KD,引入了新的KD损失和策略来解决这一限制,并已证明其性能优于标准的使用交叉熵与前向KL散度损失。例如,MiniLLM使用反向KL散度来防止学生高估教师的低概率区域。DistiLLM引入了偏斜KL损失和自适应训练策略。
  • 实现跨分词器蒸馏。目前的秘籍要求教师模型和学生模型使用相同的分词器,这限制了在不同LLM系列之间进行蒸馏的能力。目前已有关于跨分词器方法的研究(例如通用Logit蒸馏),我们可以进行探索。
  • 将蒸馏扩展到多模态LLM和编码器模型。KD秘籍的一个自然扩展是将其扩展到多模态LLM。与部署更高效的LLM类似,也有必要部署更小、更高效的多模态LLM。此外,已有工作表明LLM可以作为编码器模型(例如LLM2Vec)。将LLM作为编码器蒸馏到更小的编码器模型可能也是一个有前景的探索方向。