• 文档 >
  • 使用知识蒸馏将 Llama3.1 8B 提炼成 Llama3.2 1B
快捷方式

使用知识蒸馏将 Llama3.1 8B 提炼成 Llama3.2 1B

本指南将教您关于知识蒸馏 (KD) 的知识,并向您展示如何使用 torchtune 将 Llama3.1 8B 模型提炼成 Llama3.2 1B。如果您已经了解什么是知识蒸馏,并想直接在 torchtune 中运行您自己的蒸馏,您可以跳转到 torchtune 中的 KD 配方 教程。

您将学到什么
  • 什么是 KD,以及它如何帮助提高模型性能

  • torchtune 中 KD 组件的概述

  • 如何使用 torchtune 从教师模型蒸馏到学生模型

  • 如何试验不同的 KD 配置

先决条件

什么是知识蒸馏?

知识蒸馏 是一种广泛使用的压缩技术,可将知识从较大的(教师)模型转移到较小的(学生)模型。较大的模型具有更多的参数和知识容量,然而,这种较大的容量在部署时也更耗费计算资源。知识蒸馏可用于将较大模型的知识压缩到较小的模型中。其理念是通过学习较大模型的输出来提高较小模型的性能。

知识蒸馏是如何工作的?

知识通过在迁移集上训练学生模型从教师模型转移到学生模型,其中学生模型被训练来模仿教师的令牌级概率分布。下图是 KD 工作原理的简化表示。

../_images/kd-simplified.png

总损失可以通过多种方式配置。torchtune 中的默认 KD 配置将交叉熵 (CE) 损失与前向 库尔贝克-莱布勒 (KL) 散度 损失相结合,这在标准 KD 方法中使用。前向 KL 散度旨在通过强制学生模型的分布与教师模型的所有分布对齐来最小化差异。然而,将学生模型分布与整个教师模型分布对齐可能并非有效,并且有多篇论文,例如 MiniLLMDistiLLMGeneralized KD,介绍了新的 KD 损失来解决这些局限性。对于本教程,让我们看一下前向 KL 散度损失的实现。

import torch
import torch.nn.functional as F

class ForwardKLLoss(torch.nn.Module):
  def __init__(self, ignore_index: int = -100)
    super().__init__()
    self.ignore_index = ignore_index

  def forward(self, student_logits, teacher_logits, labels) -> torch.Tensor:
    # Implementation from https://github.com/jongwooko/distillm
    # Computes the softmax of the teacher logits
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    # Computes the student log softmax probabilities
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    # Computes the forward KL divergence
    prod_probs = teacher_prob * student_logprob
    # Compute the sum
    x = torch.sum(prod_probs, dim=-1).view(-1)
    # We don't want to include the ignore labels in the average
    mask = (labels != self.ignore_index).int()
    # Loss is averaged over non-ignored targets
    return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

为了简化计算,省略了一些细节,但如果您想了解更多信息,可以在 ForwardKLLoss 中查看实现。默认情况下,KD 配置使用 ForwardKLWithChunkedOutputLoss 来减少内存。当前的实现仅支持具有相同输出 logits 形状和相同分词器的学生和教师模型。

torchtune 中的 KD 配方

借助 torchtune,我们可以轻松地将知识蒸馏应用于 Llama3 以及其他 LLM 模型系列。让我们看一下如何使用 torchtune 的 KD 配方 蒸馏模型。

首先,确保您已下载所有模型权重。在本示例中,我们将使用 Llama3.1-8B 作为教师模型,Llama3.2-1B 作为学生模型。

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

然后,我们将使用 LoRA 微调教师模型。根据我们的实验和之前的工作,我们发现当教师模型已经在目标数据集上进行了微调时,KD 效果更好。

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

最后,我们可以运行以下命令,在单个 GPU 上将微调后的 8B 模型蒸馏到 1B 模型中。

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

消融研究

在前面的示例中,我们使用了 LoRA 微调的 8B 教师模型和基线 1B 学生模型,但我们可能希望尝试使用不同的配置和超参数。在本教程中,我们将使用 alpaca_cleaned_dataset 进行微调,并在 truthfulqa_mc2hellaswagcommonsense_qa 任务上通过 EleutherAI LM 评估工具 评估模型。让我们看看以下因素的影响:

  1. 使用微调的教师模型

  2. 使用微调的学生模型

  3. kd_ratio 和学习率的超参数调优

  4. 参数数量更接近的教师和学生模型

使用微调的教师模型

配置中的默认设置使用微调的教师模型。现在,让我们看一下首先不微调教师模型的效果。要更改教师模型,您可以修改配置中的 teacher_checkpointer

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
      model-00001-of-00004.safetensors,
      model-00002-of-00004.safetensors,
      model-00003-of-00004.safetensors,
      model-00004-of-00004.safetensors
  ]

在下表中,我们可以看到,1B 模型的标准微调比基线 1B 模型实现了更好的准确率。通过使用微调后的 8B 教师模型,我们在 truthfulqa 上看到了相当的结果,并在 hellaswag 和 commonsense 上有所改进。当使用基线 8B 作为教师模型时,我们看到所有指标都有所改进,但低于其他配置。

../_images/kd-finetune-teacher.png

查看损失,使用基线 8B 作为教师模型会导致比使用微调后的教师模型更高的损失。KD 损失也保持相对恒定,这表明教师模型应该具有与迁移数据集相同的分布。

使用微调的学生模型

对于这些实验,让我们看一下当学生模型已经微调时 KD 的效果。在这些实验中,我们查看了基线和微调后的 8B 和 1B 模型的不同组合。要更改学生模型,您可以首先微调 1B 模型,然后在配置中修改学生模型检查点。

checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
   checkpoint_files: [
     hf_model_0001_0.pt
   ]

使用微调后的学生模型进一步提高了 truthfulqa 的准确率,但 hellaswag 和 commonsense 的准确率有所下降。使用微调后的教师模型和基线学生模型在 hellaswag 和 commonsense 数据集上取得了最佳结果。根据这些发现,最佳配置将根据您要优化的评估数据集和指标而变化。

../_images/kd-finetune-student.png

根据损失图,无论学生模型是否经过微调,使用微调后的教师模型都会导致较低的损失。同样有趣的是,当使用微调后的学生模型时,类别损失开始增加。

超参数调优:学习率

默认情况下,配置中的学习率为 \(3e^{-4}\),这与 LoRA 配置相同。对于这些实验,我们将学习率从高达 \(1e^{-3}\) 更改为低至 \(1e^{-5}\)。要更改学习率,您可以简单地使用以下方式覆盖学习率参数

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3

根据结果,最佳学习率会根据您要优化的指标而变化。

../_images/kd-hyperparam-lr.png

根据损失图,除了 \(1e^{-5}\) 之外,所有学习率都导致类似的损失,其中 \(1e^{-5}\) 具有更高的 KD 损失和类别损失。

超参数调优:KD 比率

在配置中,我们将 kd_ratio 设置为 0.5,这为类别损失和 KD 损失都赋予了相等的权重。在这些实验中,我们研究了不同 KD 比率的影响,其中 0 仅使用类别损失,而 1 仅使用 KD 损失。与更改学习率类似,KD 比率可以使用以下方式进行调整

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=0.25

总体而言,较高的 KD 比率的评估结果略好。

../_images/kd-hyperparam-kd-ratio.png

Qwen2 1.5B 到 0.5B

KD 配方也可以应用于不同的模型系列。在这里,我们研究了当教师模型和学生模型之间的参数数量更接近时 KD 的效果。对于此实验,我们使用了 Qwen2 1.5B 和 Qwen2 0.5B,其配置可以在 qwen2/knowledge_distillation_single_device 配置中找到。在这里,我们看到在 alpaca cleaned 数据集上进行训练仅提高了 truthful_qa 的性能,并降低了其他评估任务的指标。对于 truthful_qa,KD 将学生模型的性能提高了 5.8%,而微调将性能提高了 1.3%。

../_images/kd-qwen2-res.png

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源