• 文档 >
  • 使用知识蒸馏将 Llama3.1 8B 蒸馏到 Llama3.2 1B
快捷方式

使用知识蒸馏将 Llama3.1 8B 蒸馏到 Llama3.2 1B

本指南将向您介绍知识蒸馏(KD)并展示如何使用 torchtune 将 Llama3.1 8B 模型蒸馏到 Llama3.2 1B。如果您已经了解知识蒸馏是什么,并想直接在 torchtune 中运行自己的蒸馏,您可以跳到torchtune 中的 KD recipe教程。

您将学到什么
  • 什么是 KD 以及它如何帮助提升模型性能

  • torchtune 中 KD 组件概述

  • 如何使用 torchtune 从教师模型蒸馏到学生模型

  • 如何尝试不同的 KD 配置

先决条件

什么是知识蒸馏?

知识蒸馏是一种广泛使用的压缩技术,它将知识从较大的(教师)模型转移到较小的(学生)模型。较大的模型拥有更多的参数和知识容量,然而,这种更大的容量在部署时也更计算昂贵。知识蒸馏可用于将较大模型的知识压缩到较小的模型中。其思想是,较小模型的性能可以通过学习较大模型的输出来得到提升。

知识蒸馏如何工作?

知识通过在一个迁移集上训练学生模型来模仿教师模型的 token 级概率分布,从而从教师模型转移到学生模型。下图是 KD 如何工作的简化表示。

../_images/kd-simplified.png

总损失可以通过多种方式配置。torchtune 中默认的 KD 配置将交叉熵 (CE) 损失与前向 Kullback-Leibler (KL) 散度损失结合,这在标准 KD 方法中使用。前向 KL 散度旨在通过强制学生模型的分布与教师模型的所有分布对齐来最小化差异。然而,将学生模型的分布与整个教师模型的分布对齐可能效果不佳,并且有一些论文,例如 MiniLLMDistiLLMGeneralized KD,引入了新的 KD 损失来解决这些限制。对于本教程,让我们看看前向 KL 散度损失的实现。

import torch
import torch.nn.functional as F

class ForwardKLLoss(torch.nn.Module):
  def __init__(self, ignore_index: int = -100)
    super().__init__()
    self.ignore_index = ignore_index

  def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:
    # Implementation from https://github.com/jongwooko/distillm
    # Computes the softmax of the teacher logits
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    # Computes the student log softmax probabilities
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    # Computes the forward KL divergence
    prod_probs = teacher_prob * student_logprob
    # Compute the sum
    x = torch.sum(prod_probs, dim=-1).view(-1)
    # We don't want to include the ignore labels in the average
    mask = (labels != self.ignore_index).int()
    # Loss is averaged over non-ignored targets
    return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

为简化计算省略了一些细节,但如果您想了解更多,可以在ForwardKLLoss中查看实现。默认情况下,KD 配置使用ForwardKLWithChunkedOutputLoss来减少内存。当前实现仅支持输出 logit 形状和分词器相同的学生模型和教师模型。

torchtune 中的 KD recipe

使用 torchtune,我们可以轻松地将知识蒸馏应用于 Llama3 以及其他 LLM 模型系列。让我们来看看如何使用 torchtune 的KD recipe来蒸馏模型。

首先,确保您已下载所有模型权重。在此示例中,我们将使用 Llama3.1-8B 作为教师模型,Llama3.2-1B 作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

然后,我们将使用 LoRA 对教师模型进行微调。根据我们的实验和先前的工作,我们发现当教师模型已在目标数据集上进行微调时,KD 的性能更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令,在单个 GPU 上将微调后的 8B 模型蒸馏到 1B 模型。

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在前面的示例中,我们使用了 LoRA 微调的 8B 教师模型和基线 1B 学生模型,但我们可能希望尝试不同的配置和超参数。在本教程中,我们将在alpaca_cleaned_dataset上进行微调,并通过 EleutherAI LM evaluation harnesstruthfulqa_mc2hellaswagcommonsense_qa 任务上评估模型。让我们看看以下因素的影响:

  1. 使用微调的教师模型

  2. 使用微调的学生模型

  3. 超参数调优:kd_ratio 和学习率

  4. 参数数量更接近的教师模型和学生模型

使用微调的教师模型

配置中的默认设置使用微调的教师模型。现在,让我们看看不先微调教师模型的效果。要更改教师模型,您可以修改配置中的teacher_checkpointer

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
      model-00001-of-00004.safetensors,
      model-00002-of-00004.safetensors,
      model-00003-of-00004.safetensors,
      model-00004-of-00004.safetensors
  ]

在下表中,我们可以看到,对 1B 模型进行标准微调比基线 1B 模型获得了更好的准确性。通过使用微调后的 8B 教师模型,我们在 truthfulqa 上看到了可比的结果,并在 hellaswag 和 commonsense 上看到了改进。当使用基线 8B 作为教师模型时,我们在所有指标上都看到了改进,但低于其他配置。

../_images/kd-finetune-teacher.png

从损失来看,使用基线 8B 作为教师模型比使用微调的教师模型导致更高的损失。KD 损失也保持相对恒定,这表明教师模型应与迁移数据集具有相同的分布。

使用微调的学生模型

对于这些实验,让我们看看当学生模型已经进行微调时,KD 的效果。在这些实验中,我们研究了基线和微调的 8B 和 1B 模型不同组合。要更改学生模型,您可以先微调 1B 模型,然后修改配置中的学生模型 checkpointer

checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
   checkpoint_files: [
     hf_model_0001_0.pt
   ]

使用微调的学生模型进一步提高了 truthfulqa 的准确性,但 hellaswag 和 commonsense 的准确性有所下降。使用微调的教师模型和基线学生模型在 hellaswag 和 commonsense 数据集上取得了最佳结果。根据这些发现,最佳配置将取决于您优化的评估数据集和指标。

../_images/kd-finetune-student.png

根据损失图,无论学生模型是否经过微调,使用微调的教师模型都会导致更低的损失。另外值得注意的是,当使用微调的学生模型时,类别损失开始增加。

超参数调优:学习率

默认情况下,配置中的学习率为 \(3e^{-4}\),与 LoRA 配置相同。对于这些实验,我们将学习率从高达 \(1e^{-3}\) 更改为低至 \(1e^{-5}\)。要更改学习率,只需使用以下方法覆盖学习率参数即可:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3

根据结果,最优学习率会根据您优化的指标而改变。

../_images/kd-hyperparam-lr.png

根据损失图,除了学习率为 \(1e^{-5}\) 外,所有学习率都产生了相似的损失,而 \(1e^{-5}\) 的 KD 损失和类别损失较高。

超参数调优:KD 比例

在配置中,我们将kd_ratio设置为 0.5,这为类别损失和 KD 损失赋予了相等的权重。在这些实验中,我们考察了不同 KD 比例的影响,其中 0 表示仅使用类别损失,1 表示仅使用 KD 损失。与更改学习率类似,KD 比例可以使用以下方法进行调整:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=0.25

总体而言,较高的 KD 比例的评估结果略好。

../_images/kd-hyperparam-kd-ratio.png

Qwen2 1.5B 到 0.5B

KD recipe 也可应用于不同的模型系列。在这里,我们考察了当教师模型和学生模型之间的参数数量更接近时 KD 的效果。对于此实验,我们使用了 Qwen2 1.5B 和 Qwen2 0.5B,其配置文件可在qwen2/knowledge_distillation_single_device配置中找到。在这里,我们看到在 alpaca cleaned 数据集上训练仅提高了 truthful_qa 的性能,而其他评估任务的指标有所下降。对于 truthful_qa,KD 将学生模型性能提高了 5.8%,而微调仅提高了 1.3%。

../_images/kd-qwen2-res.png

文档

查阅全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源