• 文档 >
  • 使用聊天数据微调 Llama3
快捷方式

使用聊天数据微调 Llama3

Llama3 Instruct 引入了用于使用聊天数据微调的新提示模板。在本教程中,我们将介绍快速开始准备自己的自定义聊天数据集以微调 Llama3 Instruct 所需的知识。

您将学到
  • Llama3 Instruct 格式与 Llama2 的区别

  • 关于提示模板和特殊标记

  • 如何使用您自己的聊天数据集微调 Llama3 Instruct

先决条件

注意

本教程需要 torchtune > 0.1.1 版本

从 Llama2 到 Llama3 的模板更改

Llama2 聊天模型在提示预训练模型时需要特定的模板。由于聊天模型是使用此提示模板预训练的,如果您想在模型上运行推理,则需要使用相同的模板才能在聊天数据上获得最佳性能。否则,模型将仅执行标准文本完成,这可能与您的预期用例一致,也可能不一致。

Llama2 官方提示模板指南 的 Llama2 聊天模型中,我们可以看到添加了特殊标签

<s>[INST] <<SYS>>
You are a helpful, respectful, and honest assistant.
<</SYS>>

Hi! I am a human. [/INST] Hello there! Nice to meet you! I'm Meta AI, your friendly AI assistant </s>

Llama3 Instruct 彻底改造 了 Llama2 的模板,以更好地支持多轮对话。Llama3 Instruct 格式中的相同文本将如下所示

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful, respectful, and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

Hi! I am a human.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Hello there! Nice to meet you! I'm Meta AI, your friendly AI assistant<|eot_id|>

这些标签完全不同,并且实际上它们与 Llama2 中的编码方式不同。让我们逐步完成使用 Llama2 模板和 Llama3 模板对示例进行标记的过程,以了解其工作原理。

注意

Llama3 Base 模型使用的是 与 Llama3 Instruct 不同的提示模板,因为它尚未经过指令微调,并且额外的特殊标记未经训练。如果您在没有微调的情况下在 Llama3 Base 模型上运行推理,我们建议使用基础模板以获得最佳性能。通常,对于指令和聊天数据,我们建议使用 Llama3 Instruct 及其提示模板。本教程的其余部分假设您正在使用 Llama3 Instruct。

标记提示模板和特殊标记

假设我有一个包含单一用户-助手回合以及系统提示的示例。

sample = [
    {
        "role": "system",
        "content": "You are a helpful, respectful, and honest assistant.",
    },
    {
        "role": "user",
        "content": "Who are the most influential hip-hop artists of all time?",
    },
    {
        "role": "assistant",
        "content": "Here is a list of some of the most influential hip-hop "
        "artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.",
    },
]

现在,让我们使用 Llama2ChatFormat 类对其进行格式化,并查看它如何被标记。Llama2ChatFormat 是 提示模板 的一个示例,它只是使用风格文本构建一个提示,以指示特定任务。

from torchtune.data import Llama2ChatFormat, Message

messages = [Message.from_dict(msg) for msg in sample]
formatted_messages = Llama2ChatFormat.format(messages)
print(formatted_messages)
# [
#     Message(
#         role='user',
#         content='[INST] <<SYS>>\nYou are a helpful, respectful, and honest assistant.\n<</SYS>>\n\nWho are the most influential hip-hop artists of all time? [/INST] ',
#         ...,
#     ),
#     Message(
#         role='assistant',
#         content='Here is a list of some of the most influential hip-hop artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.',
#         ...,
#     ),
# ]

Llama2 还使用了特殊标记,这些标记不在提示模板中。如果您查看我们的 Llama2ChatFormat 类,您会注意到我们没有包含 <s></s> 标记。这些是序列开始 (BOS) 和序列结束 (EOS) 标记,它们在标记器中的表示方式与提示模板中的其他内容不同。让我们使用 Llama2 使用的 llama2_tokenizer() 对此示例进行标记,以了解原因。

from torchtune.models.llama2 import llama2_tokenizer

tokenizer = llama2_tokenizer("/tmp/Llama-2-7b-hf/tokenizer.model")
user_message = formatted_messages[0].content
tokens = tokenizer.encode(user_message, add_bos=True, add_eos=True)
print(tokens)
# [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, ..., 2]

我们在编码示例文本时添加了 BOS 和 EOS 标记。这显示为 ID 1 和 2。我们可以验证这些是我们的 BOS 和 EOS 标记。

print(tokenizer._spm_model.spm_model.piece_to_id("<s>"))
# 1
print(tokenizer._spm_model.spm_model.piece_to_id("</s>"))
# 2

BOS 和 EOS 标记是我们所谓的特殊标记,因为它们有自己的保留标记 ID。这意味着它们将在模型学习的嵌入表中索引到它们自己的单个向量。提示模板标签的其余部分,[INST]<<SYS>> 作为普通文本进行标记,而不是它们自己的 ID。

print(tokenizer.decode(518))
# '['
print(tokenizer.decode(25580))
# 'INST'
print(tokenizer.decode(29962))
# ']'
print(tokenizer.decode([3532, 14816, 29903, 6778]))
# '<<SYS>>'

重要的是要注意,您不应该在输入提示中手动放置保留的特殊标记,因为它们将被视为普通文本,而不是特殊标记。

print(tokenizer.encode("<s>", add_bos=False, add_eos=False))
# [529, 29879, 29958]

现在让我们看一下 Llama3 的格式,看看它与 Llama2 的标记方式有何不同。

from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer("/tmp/Meta-Llama-3-8B/original/tokenizer.model")
messages = [Message.from_dict(msg) for msg in sample]
tokens, mask = tokenizer.tokenize_messages(messages)
print(tokenizer.decode(tokens))
# '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful,
# and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho
# are the most influential hip-hop artists of all time?<|eot_id|><|start_header_id|>
# assistant<|end_header_id|>\n\nHere is a list of some of the most influential hip-hop
# artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.<|eot_id|>'

注意

我们使用了 Llama3 的 tokenize_messages API,它与 encode 不同。它只是在对各个消息进行编码后,将所有特殊标记添加到正确的位置。

我们可以看到,标记器在没有我们指定提示模板的情况下处理了所有格式。事实证明,所有这些额外的标签都是特殊标记,我们不需要单独的提示模板。我们可以通过检查这些标签是否作为它们自己的标记 ID 进行编码来验证这一点。

print(tokenizer.special_tokens["<|begin_of_text|>"])
# 128000
print(tokenizer.special_tokens["<|eot_id|>"])
# 128009

最棒的是,所有这些特殊标记都是由标记器完全处理的。这意味着您不必担心搞乱任何所需的提示模板!

我什么时候应该使用提示模板?

是否使用提示模板取决于您希望的推理行为。如果您在基础模型上运行推理,并且它是在使用提示模板预训练的,或者您希望对微调后的模型进行初始化,以便在推理时针对特定任务期望特定提示结构,则应使用提示模板。

使用提示模板进行微调并不严格必要,但通常特定任务需要特定模板。例如,SummarizeTemplate 提供了一个轻量级结构来为您的微调后的模型进行初始化,以用于要求对文本进行摘要的提示。这将围绕用户消息进行包装,而助手消息保持不变。

f"Summarize this dialogue:\n{dialogue}\n---\nSummary:\n"

即使模型最初是使用 Llama2ChatFormat 预训练的,您也可以使用此模板对 Llama2 进行微调,只要模型在推理时看到的就是它。模型应该足够健壮,能够适应新的模板。

在自定义聊天数据集上进行微调

让我们通过尝试使用自定义聊天数据集对 Llama3-8B 指令模型进行微调来测试我们的理解。我们将逐步介绍如何设置数据,以便可以将其正确标记并馈送到模型中。

假设我们有一个本地数据集,它存储为 CSV 文件,其中包含来自在线论坛的问题和答案。我们如何将类似这样的内容转换成 Llama3 能够理解并正确标记的格式?

import pandas as pd

df = pd.read_csv('your_file.csv', nrows=1)
print("Header:", df.columns.tolist())
# ['input', 'output']
print("First row:", df.iloc[0].tolist())
# [
#     "How do GPS receivers communicate with satellites?",
#     "The first thing to know is the communication is one-way...",
# ]

Llama3 分词器类,Llama3Tokenizer,期望输入为 Message 格式。让我们快速编写一个函数,可以将 CSV 文件中的一行解析为 Message 数据类。该函数还需要包含一个 train_on_input 参数。

def message_converter(sample: Mapping[str, Any], train_on_input: bool) -> List[Message]:
    input_msg = sample["input"]
    output_msg = sample["output"]

    user_message = Message(
        role="user",
        content=input_msg,
        masked=not train_on_input,  # Mask if not training on prompt
    )
    assistant_message = Message(
        role="assistant",
        content=output_msg,
        masked=False,
    )
    # A single turn conversation
    messages = [user_message, assistant_message]

    return messages

由于我们正在对 Llama3 进行微调,分词器将为我们处理提示格式化。但是,如果我们正在微调需要模板的模型,例如 Mistral-7B 模型,它使用 MistralTokenizer,我们需要使用类似 MistralChatFormat 的聊天格式,根据他们的 推荐 来格式化所有消息。

现在让我们为数据集创建一个构建器函数,该函数加载本地文件,使用我们的函数将其转换为 Message 列表,并创建一个 ChatDataset 对象。

def custom_dataset(
    *,
    tokenizer: ModelTokenizer,
    max_seq_len: int = 2048,  # You can expose this if you want to experiment
) -> ChatDataset:

    return ChatDataset(
        tokenizer=tokenizer,
        # For local csv files, we specify "csv" as the source, just like in
        # load_dataset
        source="csv",
        # Default split of "train" is required for local files
        split="train",
        convert_to_messages=message_converter,
        # Llama3 does not need a chat format
        chat_format=None,
        max_seq_len=max_seq_len,
        # To load a local file we specify it as data_files just like in
        # load_dataset
        data_files="your_file.csv",
    )

注意

您可以将 load_dataset 的任何关键字参数传递到我们所有 Dataset 类中,它们会遵守这些参数。这对常见的参数非常有用,例如使用 split 指定数据分割或使用 name 指定配置。

现在我们准备开始微调了!我们将使用内置的 LoRA 单设备配方。使用 tune cp 命令获取 8B_lora_single_device.yaml 配置的副本,并更新它以使用您的新数据集。为您的项目创建一个新文件夹,并确保数据集构建器和消息转换器保存在该目录中,然后在配置中指定它。

dataset:
  _component_: path.to.my.custom_dataset
  max_seq_len: 2048

启动微调!

$ tune run lora_finetune_single_device --config custom_8B_lora_single_device.yaml epochs=15

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源