• 文档 >
  • 使用知识蒸馏将 Llama3.1 8B 蒸馏到 Llama3.2 1B
快捷方式

使用知识蒸馏将 Llama3.1 8B 蒸馏到 Llama3.2 1B

本指南将向您介绍知识蒸馏 (KD) 并向您展示如何使用 torchtune 将 Llama3.1 8B 模型蒸馏到 Llama3.2 1B。如果您已经了解知识蒸馏是什么,并且希望直接在 torchtune 中运行自己的蒸馏,您可以跳到torchtune 中的 KD 食谱教程。

您将学到什么
  • 什么是 KD 以及它如何帮助提高模型性能

  • torchtune 中 KD 组件的概述

  • 如何使用 torchtune 从教师模型到学生模型进行蒸馏

  • 如何尝试不同的 KD 配置

先决条件

什么是知识蒸馏?

知识蒸馏是一种广泛使用的压缩技术,它将知识从更大的(教师)模型转移到更小的(学生)模型。更大的模型具有更多的参数和知识容量,但是,这种更大的容量在部署方面也更耗费计算资源。知识蒸馏可用于将更大模型的知识压缩到更小的模型中。其思想是,可以通过学习更大模型的输出来提高较小模型的性能。

知识蒸馏是如何工作的?

通过在传输集上训练学生模型来将知识从教师模型转移到学生模型,其中学生模型被训练以模仿教师模型的标记级概率分布。下图简化了 KD 的工作原理。

../_images/kd-simplified.png

总损失可以通过多种方式配置。torchtune 中的默认 KD 配置将交叉熵 (CE) 损失与前向Kullback-Leibler (KL) 散度损失结合使用,这在标准 KD 方法中使用。前向 KL 散度旨在通过强制学生分布与所有教师分布对齐来最小化差异。但是,将学生分布与整个教师分布对齐可能并不有效,并且有多篇论文,如MiniLLMDistiLLMGeneralized KD,引入了新的 KD 损失来解决这些限制。在本教程中,让我们看看前向 KL 散度损失的实现。

import torch
import torch.nn.functional as F

class ForwardKLLoss(torch.nn.Module):
  def __init__(self, ignore_index: int = -100)
    super().__init__()
    self.ignore_index = ignore_index

  def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:
    # Implementation from https://github.com/jongwooko/distillm
    # Computes the softmax of the teacher logits
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    # Computes the student log softmax probabilities
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    # Computes the forward KL divergence
    prod_probs = teacher_prob * student_logprob
    # Compute the sum
    x = torch.sum(prod_probs, dim=-1).view(-1)
    # We don't want to include the ignore labels in the average
    mask = (labels != self.ignore_index).int()
    # Loss is averaged over non-ignored targets
    return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

为了简化计算,省略了一些细节,但如果您想了解更多信息,可以在ForwardKLLoss中查看实现。默认情况下,KD 配置使用ForwardKLWithChunkedOutputLoss来减少内存。当前实现仅支持输出 logits 形状和分词器相同的学生和教师模型。

torchtune 中的 KD 食谱

使用 torchtune,我们可以轻松地将知识蒸馏应用于 Llama3 以及其他 LLM 模型系列。让我们看看如何使用 torchtune 的KD 食谱蒸馏模型。

首先,确保您已下载所有模型权重。在本例中,我们将使用 Llama3.1-8B 作为教师模型,Llama3.2-1B 作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

然后,我们将使用 LoRA 微调教师模型。根据我们的实验和之前的工作,我们发现当教师模型已经在目标数据集上进行微调时,KD 的性能更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令在单个 GPU 上将微调后的 8B 模型蒸馏到 1B 模型中。

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在前面的示例中,我们使用了 LoRA 微调的 8B 教师模型和基线 1B 学生模型,但我们可能希望尝试一些不同的配置和超参数。在本教程中,我们将对alpaca_cleaned_dataset进行微调,并在truthfulqa_mc2hellaswagcommonsense_qa任务上评估模型,这些任务通过 EleutherAI LM 评估工具进行评估。让我们看看以下方面的影响

  1. 使用微调的教师模型

  2. 使用微调的学生模型

  3. kd_ratio 和学习率的超参数调整

  4. 参数数量更接近的教师和学生模型

使用微调的教师模型

配置中的默认设置使用微调的教师模型。现在,让我们看看不首先微调教师模型的影响。要更改教师模型,您可以修改配置中的teacher_checkpointer

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
      model-00001-of-00004.safetensors,
      model-00002-of-00004.safetensors,
      model-00003-of-00004.safetensors,
      model-00004-of-00004.safetensors
  ]

在下表中,我们可以看到 1B 模型的标准微调比基线 1B 模型实现了更高的准确率。通过使用微调的 8B 教师模型,我们看到 truthfulqa 的结果相当,hellaswag 和 commonsense 的结果有所提高。当使用基线 8B 作为教师时,我们看到所有指标都得到了改善,但低于其他配置。

../_images/kd-finetune-teacher.png

查看损失,使用基线 8B 作为教师会导致比使用微调教师模型更高的损失。KD 损失也保持相对恒定,这表明教师模型应该与传输数据集具有相同的分布。

使用微调的学生模型

在这些实验中,让我们看看当学生模型已经过微调时,KD 的效果。在这些实验中,我们查看了基线模型和微调后的 8B 和 1B 模型的不同组合。要更改学生模型,您可以先微调 1B 模型,然后修改配置中的学生模型检查点。

checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
   checkpoint_files: [
     hf_model_0001_0.pt
   ]

使用微调的学生模型可以进一步提高 truthfulqa 的准确性,但 hellaswag 和 commonsense 的准确性下降了。使用微调的教师模型和基线学生模型在 hellaswag 和 commonsense 数据集上取得了最佳结果。根据这些发现,最佳配置将根据您要优化的评估数据集和指标而改变。

../_images/kd-finetune-student.png

根据损失图,使用微调的教师模型会导致更低的损失,无论学生模型是否经过微调。有趣的是,当使用微调的学生模型时,分类损失开始增加。

超参数调整:学习率

默认情况下,配置中的学习率为 \(3e^{-4}\),与 LoRA 配置相同。在这些实验中,我们将学习率从高达 \(1e^{-3}\) 降低到低至 \(1e^{-5}\)。要更改学习率,您可以简单地使用以下方法覆盖学习率参数:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3

根据结果,最佳学习率会根据您要优化的指标而改变。

../_images/kd-hyperparam-lr.png

根据损失图,所有学习率导致的损失都相似,除了 \(1e^{-5}\),它的 KD 损失和分类损失更高。

超参数调整:KD 比率

在配置中,我们将 kd_ratio 设置为 0.5,这使得分类损失和 KD 损失的权重相等。在这些实验中,我们查看了不同 KD 比率的影响,其中 0 仅使用分类损失,1 仅使用 KD 损失。与更改学习率类似,可以使用以下方法调整 KD 比率:

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=0.25

总体而言,评估结果对于较高的 KD 比率略好。

../_images/kd-hyperparam-kd-ratio.png

Qwen2 1.5B 到 0.5B

KD 方法也可以应用于不同的模型系列。在这里,我们查看当教师模型和学生模型之间的参数数量更接近时,KD 的效果。在这个实验中,我们使用了 Qwen2 1.5B 和 Qwen2 0.5B,它们的配置可以在 qwen2/knowledge_distillation_single_device 配置中找到。在这里我们看到,在 alpaca 清理后的数据集上进行训练只会提高 truthful_qa 的性能,并降低其他评估任务的指标。对于 truthful_qa,KD 将学生模型的性能提高了 5.8%,而微调将性能提高了 1.3%。

../_images/kd-qwen2-res.png

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源