作者:Linda Wang, Evan Smothers, Kartikay Khandelwal

在这篇博客中,我们介绍了使用 torchtune 的知识蒸馏配方将 Llama 3.1 8B 模型蒸馏到 Llama 3.2 1B 模型中的案例研究。我们展示了如何在训练后使用知识蒸馏(KD)来提高指令遵循任务的性能,并展示了用户如何利用该配方。

什么是知识蒸馏?

知识蒸馏是一种广泛使用的压缩技术,它将知识从一个较大的(教师)模型转移到一个较小的(学生)模型。较大的模型拥有更多的参数和知识容量,然而,这种较大的容量在部署时计算成本也更高。知识蒸馏可用于将较大模型的知识压缩到较小的模型中。其思想是通过从较大模型的输出中学习来提高较小模型的性能。

知识蒸馏如何工作?

通过在迁移集上训练,将知识从教师模型转移到学生模型,学生模型被训练来模仿教师模型的 token 级别概率分布。假设教师模型分布与迁移数据集相似。下图是 KD 工作原理的简化表示。

Figure 1: Simplified representation of knowledge transfer from teacher to student model

图 1:知识从教师模型转移到学生模型的简化表示

由于大型语言模型(LLMs)的知识蒸馏是一个活跃的研究领域,有诸如 MiniLLMDistiLLMAKLGeneralized KD 等论文研究了不同的损失方法。在本案例研究中,我们以标准交叉熵(CE)损失与前向 Kullback-Leibler (KL) 散度损失作为基线。前向 KL 散度旨在通过强制学生模型的分布与教师模型的所有分布对齐来最小化差异。

为什么知识蒸馏有用?

知识蒸馏的思想是,通过使用教师模型的输出来作为附加信号,较小的模型可以比从头开始训练或仅使用有监督微调取得更好的性能。例如,Llama 3.2 轻量级 1B 和 3B 文本模型纳入了 Llama 3.1 8B 和 70B 的 logits,以在剪枝后恢复性能。此外,对于指令遵循任务的微调,LLM 蒸馏研究表明,知识蒸馏方法可以超越单独的有监督微调(SFT)。

模型 方法 DollyEval Self-Inst S-NI
GPT-4 Eval GPT-4 Eval Rouge-L
Llama 7B SFT 73.0 69.2 32.4
KD 73.7 70.5 33.7
MiniLLM 76.4 73.1 35.5
Llama 1.1B SFT 22.1 - 27.8
KD 22.2 - 28.1
AKL 24.4 - 31.4
OpenLlama 3B SFT 47.3 41.7 29.3
KD 44.9 42.1 27.9
SeqKD 48.1 46.0 29.1
DistiLLM 59.9 53.3 37.6

表 1:知识蒸馏方法与有监督微调的比较

下面是知识蒸馏与有监督微调区别的简化示例。

有监督微调 知识蒸馏
   
model = llama3_2_1b()
ce_loss = CrossEntropyLoss()
kd_loss = ForwardKLLoss()

tokens, labels = batch["tokens"], batch["labels"]
logits = model(tokens, ...)

loss = ce_loss(logits, labels)
loss.backward()

   
   
   
model = llama3_2_1b()
teacher_model = llama3_1_8b()
ce_loss = CrossEntropyLoss()
kd_loss = ForwardKLLoss()

tokens, labels = batch["tokens"], batch["labels"]
logits = model(tokens, ...)
teacher_logits = teacher_model(tokens, ...)
loss = ce_loss(logits, labels) + kd_loss(logits, teacher_logits, labels)
loss.backward()
   
   

torchtune 中的 KD 配方

使用 torchtune,我们可以轻松地将知识蒸馏应用于 Llama3 以及其他 LLM 模型系列,通过 torchtune 的 KD 配方。此配方的目标是通过从 Llama3.1-8B 中蒸馏,在 Alpaca 指令遵循数据集上微调 Llama3.2-1B。此配方专注于训练后阶段,并假定教师和学生模型已经过预训练。

首先,我们必须下载模型权重。为了与其他 torchtune 微调配置保持一致,我们将使用 Llama3.1-8B 的指令微调模型作为教师模型,Llama3.2-1B 作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

为了使教师模型的分布与 Alpaca 数据集相似,我们将使用 LoRA 微调教师模型。根据我们在下一节中展示的实验,我们发现当教师模型已经在目标数据集上进行过微调时,KD 的性能会更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令,在单个 GPU 上将微调后的 8B 模型蒸馏到 1B 模型中。在本案例研究中,我们使用了单个 A100 80GB GPU。我们还提供了用于在多个设备上运行的分布式配方

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在本节中,我们将演示更改配置和超参数如何影响性能。默认情况下,我们的配置使用 LoRA 微调后的 8B 教师模型、下载的 1B 学生模型、学习率为 3e-4 和 KD 损失比例为 0.5。在本案例研究中,我们在 alpaca_cleaned_dataset 数据集上进行了微调,并通过 EleutherAI 的 LM evaluation harnesstruthfulqa_mc2hellaswagcommonsense_qa 任务上评估了模型。让我们看看以下因素的影响:

  1. 使用微调后的教师模型
  2. 使用微调后的学生模型
  3. KD 损失比例和学习率的超参数调优

使用微调后的教师模型

配置文件中的默认设置使用微调后的教师模型。现在,让我们看看不先微调教师模型的效果。

从损失来看,使用基线 8B 作为教师模型会导致比使用微调后教师模型更高的损失。KD 损失也相对保持不变,这表明教师模型应该与迁移数据集具有相同的分布。

Figure 2: (left to right) KD loss from forward KL divergence, class loss from cross entropy, total loss: even combination of KD and class loss.

图 2:(从左到右)前向 KL 散度的 KD 损失,交叉熵的类别损失,总损失:KD 损失和类别损失的平均组合。

在我们的基准测试中,我们可以看到 1B 模型的有监督微调比基线 1B 模型获得了更好的准确率。通过使用微调后的 8B 教师模型,我们在 truthfulqa 上看到了可比的结果,并在 hellaswag 和 commonsense 上看到了提升。当使用基线 8B 作为教师模型时,我们在所有指标上都看到了提升,但低于其他配置。

模型 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调后的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基线 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调后的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基线 8B 作为教师模型的 KD 0.444 0.4576 0.6123 0.5561
使用微调后 8B 作为教师模型的 KD 0.4481 0.4603 0.6157 0.5569

表 2:使用基线和微调后 8B 作为教师模型的比较

使用微调后的学生模型

在这些实验中,我们研究了当学生模型已微调时 KD 的效果。我们分析了使用基线和微调后的 8B 和 1B 模型的不同组合的效果。

根据损失图,无论学生模型是否经过微调,使用微调后的教师模型都会导致较低的损失。同样值得注意的是,当使用微调后的学生模型时,类别损失开始增加。

Figure 3: Comparing losses of different teacher and student model initializations

图 3:比较不同教师和学生模型初始化时的损失

使用微调后的学生模型进一步提升了 truthfulqa 的准确率,但在 hellaswag 和 commonsense 上的准确率有所下降。使用微调后的教师模型和基线学生模型在 hellaswag 和 commonsense 数据集上取得了最佳结果。根据这些发现,最佳配置将取决于您要优化的评估数据集和指标。

模型 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调后的 Llama 3.1 8B 0.5475 0.6031 0.7951 0.7789
基线 Llama 3.2 1B 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调后的 Llama 3.2 1B 0.4492 0.4595 0.6132 0.5528
使用基线 8B 和基线 1B 的 KD 0.444 0.4576 0.6123 0.5561
使用基线 8B 和微调后 1B 的 KD 0.4508 0.448 0.6004 0.5274
使用微调后 8B 和基线 1B 的 KD 0.4481 0.4603 0.6157 0.5569
使用微调后 8B 和微调后 1B 的 KD 0.4713 0.4512 0.599 0.5233

表 3:使用基线和微调后的教师和学生模型的比较

超参数调优:学习率

默认情况下,该配方的学习率为 3e-4。在这些实验中,我们将学习率从高达 1e-3 更改为低至 1e-5。

根据损失图,除 1e-5 外,所有学习率都导致相似的损失,而 1e-5 的 KD 损失和类别损失更高。

Figure 4: Comparing losses of different learning rates

图 4:比较不同学习率的损失

根据我们的基准测试,最优学习率取决于您要优化的指标和任务。

模型 学习率 TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B - 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调后的 Llama 3.1 8B - 0.5475 0.6031 0.7951 0.7789
基线 Llama 3.2 1B - 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调后的 Llama 3.2 1B - 0.4492 0.4595 0.6132 0.5528
使用微调后 8B 和基线 1B 的 KD 3e-4 0.4481 0.4603 0.6157 0.5569
使用微调后 8B 和基线 1B 的 KD 1e-3 0.4453 0.4535 0.6071 0.5258
使用微调后 8B 和基线 1B 的 KD 1e-4 0.4489 0.4606 0.6156 0.5586
使用微调后 8B 和基线 1B 的 KD 1e-5 0.4547 0.4548 0.6114 0.5487

表 4:调优学习率的效果

超参数调优:KD 比例

默认情况下,KD 比例设置为 0.5,这意味着类别损失和 KD 损失的权重相等。在这些实验中,我们研究了不同 KD 比例的影响,其中 0 只使用类别损失,而 1 只使用 KD 损失。

总体而言,基准测试结果表明,对于这些任务和指标,较高的 KD 比例表现略好。

模型 kd_ratio (lr=3e-4) TruthfulQA hellaswag commonsense
mc2 acc acc_norm acc
基线 Llama 3.1 8B - 0.5401 0.5911 0.7915 0.7707
使用 LoRA 微调后的 Llama 3.1 8B - 0.5475 0.6031 0.7951 0.7789
基线 Llama 3.2 1B - 0.4384 0.4517 0.6064 0.5536
使用 LoRA 微调后的 Llama 3.2 1B - 0.4492 0.4595 0.6132 0.5528
使用微调后 8B 和基线 1B 的 KD 0.25 0.4485 0.4595 0.6155 0.5602
使用微调后 8B 和基线 1B 的 KD 0.5 0.4481 0.4603 0.6157 0.5569
使用微调后 8B 和基线 1B 的 KD 0.75 0.4543 0.463 0.6189 0.5643
使用微调后 8B 和基线 1B 的 KD 1.0 0.4537 0.4641 0.6177 0.5717

表 5:调优 KD 比例的效果

展望未来

在这篇博客中,我们介绍了一项关于如何使用 torchtune 通过前向 KL 散度损失对 Llama 3.1 8B 和 Llama 3.2 1B 的 logits 进行 LLM 蒸馏的研究。未来还有许多探索方向,可以进一步提高性能并在蒸馏方法中提供更多灵活性。

  • 扩展 KD 损失种类。KD 配方使用了前向 KL 散度损失。然而,如前所述,将学生模型的分布与整个教师模型的分布对齐可能效率不高。有多种论文,例如 MiniLLMDistiLLMGeneralized KD,引入了新的 KD 损失和策略来解决这一限制,并已证明优于标准使用带有前向 KL 散度损失的交叉熵。例如,MiniLLM 使用反向 KL 散度来防止学生模型高估教师模型的低概率区域。DistiLLM 引入了一种偏斜的 KL 损失和自适应训练策略。
  • 实现跨 tokenizer 蒸馏。当前配方要求教师模型和学生模型使用相同的 tokenizer,这限制了跨不同 LLM 系列进行蒸馏的能力。目前已有关于跨 tokenizer 方法的研究(例如 Universal Logit Distillation),我们可以进行探索。
  • 将蒸馏扩展到多模态 LLMs 和编码器模型。KD 配方的一个自然扩展是将其扩展到多模态 LLMs。与部署更高效的 LLMs 类似,也需要部署更小、更高效的多模态 LLMs。此外,也有工作证明了 LLMs 可以用作编码器模型(例如 LLM2Vec)。将用作编码器的 LLMs 蒸馏到更小的编码器模型中,也可能是值得探索的有前景的方向。