快捷方式

配置微调数据集

本教程将指导您如何设置用于微调的数据集。

您将学到什么
  • 如何快速开始使用内置数据集

  • 如何从配置中配置现有的数据集类

  • 如何完全自定义您自己的数据集

先决条件

数据集是微调工作流程的核心组件,充当“方向盘”,引导 LLM 为特定用例生成内容。许多公开共享的开源数据集已成为微调 LLM 的热门选择,并作为训练模型的良好起点。我们支持几个广泛使用的数据集,以帮助您快速启动微调。让我们逐步了解如何设置一个用于微调的常见数据集。

您可以轻松地直接从配置文件中指定数据集

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset

这将指示配方创建数据集对象,该对象迭代来自 HuggingFace 数据集上的 tatsu-lab/alpaca 的样本。

我们还公开了常见的旋钮,以根据您的需要调整数据集。例如,假设您希望在不更改批次大小的情况下减少每个批次的内存占用。您可以直接从配置中调整 max_seq_len 来实现这一点。

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  # Original is 512
  max_seq_len: 256

自定义指令模板

为了在特定任务上微调 LLM,一种常见的方法是创建一个固定的指令模板,引导模型以特定目标生成输出。指令模板只是用于构建模型输入的文本格式。它与模型无关,并且像任何其他文本一样被正常分词,但它可以帮助调节模型更好地响应预期格式。例如,AlpacaInstructTemplate 以以下方式构建数据

"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"

以下是用 AlpacaInstructTemplate 格式化的示例

from torchtune.data import AlpacaInstructTemplate

sample = {
    "instruction": "Classify the following into animals, plants, and minerals",
    "input": "Oak tree, copper ore, elephant",
}
prompt = AlpacaInstructTemplate.format(sample)
print(prompt)
# Below is an instruction that describes a task, paired with an input that provides further context.
# Write a response that appropriately completes the request.
#
# ### Instruction:
# Classify the following into animals, plants, and minerals
#
# ### Input:
# Oak tree, copper ore, elephant
#
# ### Response:
#

我们为常见的任务(如摘要和语法校正)提供了 其他指令模板。如果您需要为自定义任务创建自己的指令模板,您可以创建自己的 InstructTemplate 类,并在配置中指向它。

dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: mydataset/onthehub
  template: CustomTemplate
  train_on_input: True
  max_seq_len: 512

自定义聊天格式

聊天格式类似于指令模板,不同之处在于它们在消息列表中格式化系统、用户和助手消息(参见 ChatFormat),用于对话数据集。这些可以与指令数据集类似地配置。

dataset:
  _component_: torchtune.datasets.chat_dataset
  source: Open-Orca/SlimOrca-Dedup
  conversation_style: sharegpt
  chat_format: Llama2ChatFormat

以下是使用 Llama2ChatFormat 格式化消息的方式

from torchtune.data import Llama2ChatFormat, Message

messages = [
    Message(
        role="system",
        content="You are a helpful, respectful, and honest assistant.",
    ),
    Message(
        role="user",
        content="I am going to Paris, what should I see?",
    ),
    Message(
        role="assistant",
        content="Paris, the capital of France, is known for its stunning architecture..."
    ),
]
formatted_messages = Llama2ChatFormat.format(messages)
print(formatted_messages)
# [
#     Message(
#         role="user",
#         content="[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\n"
#         "I am going to Paris, what should I see? [/INST] ",
#     ),
#     Message(
#         role="assistant",
#         content="Paris, the capital of France, is known for its stunning architecture..."
#     ),
# ]

请注意,系统消息现在已包含在用户消息中。如果您创建自定义 ChatFormats,您还可以添加更高级的行为。

完全自定义的数据集

更高级的任务和数据集格式可能需要您创建自己的数据集类以获得更大的灵活性。让我们看一下 PreferenceDataset 的代码,它具有用于 RLHF 偏好数据的自定义功能,以了解您需要做什么。

如果您查看 PreferenceDataset 类的代码,您会注意到它与 InstructDataset 非常相似,只是对偏好数据中选定和拒绝的样本进行了一些调整。

chosen_message = [
    Message(role="user", content=prompt, masked=True),
    Message(role="assistant", content=transformed_sample[key_chosen]),
]
rejected_message = [
    Message(role="user", content=prompt, masked=True),
    Message(role="assistant", content=transformed_sample[key_rejected]),
]

chosen_input_ids, c_masks = self._tokenizer.tokenize_messages(
    chosen_message, self.max_seq_len
)
chosen_labels = list(
    np.where(c_masks, CROSS_ENTROPY_IGNORE_IDX, chosen_input_ids)
)

rejected_input_ids, r_masks = self._tokenizer.tokenize_messages(
    rejected_message, self.max_seq_len
)
rejected_labels = list(
    np.where(r_masks, CROSS_ENTROPY_IGNORE_IDX, rejected_input_ids)
)

如果任何现有的数据集类都不符合您的目的,您可以类似地使用其中一个作为起点,并添加您需要的功能。

为了能够从配置中使用您的自定义数据集,您需要创建一个构建器函数。这是 stack_exchanged_paired_dataset() 的构建器函数,它创建一个 PreferenceDataset,配置为使用来自 Hugging Face 的配对数据集。请注意,我们还必须添加一个自定义指令模板。

def stack_exchanged_paired_dataset(
    tokenizer: Tokenizer,
    max_seq_len: int = 1024,
) -> PreferenceDataset:
    return PreferenceDataset(
        tokenizer=tokenizer,
        source="lvwerra/stack-exchange-paired",
        template=StackExchangedPairedTemplate(),
        column_map={
            "prompt": "question",
            "chosen": "response_j",
            "rejected": "response_k",
        },
        max_seq_len=max_seq_len,
        split="train",
        data_dir="data/rl",
    )

现在我们可以轻松地从配置中指定我们的自定义数据集。

# This is how you would configure the Alpaca dataset using the builder
dataset:
  _component_: torchtune.datasets.stack_exchanged_paired_dataset
  max_seq_len: 512

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源