消息转换
消息转换执行从数据集中的原始样本字典到 torchtune 的 Message
结构的转换。一旦您的数据表示为消息,torchtune 将处理分词并为模型准备数据。
配置消息转换
我们的大多数内置消息转换都包含用于控制输入掩码 (train_on_input
)、添加系统提示 (new_system_prompt
) 和更改预期列名 (column_map
) 的参数。这些参数在我们的数据集构建器 instruct_dataset()
和 chat_dataset()
中公开,因此您不必担心消息转换本身,并且可以直接从配置中配置它。您可以查看 指令数据集示例 或 聊天数据集示例 以了解更多详细信息。
自定义消息转换
如果我们的内置消息转换不能很好地配置您的特定数据集,您可以创建自己的类以获得完全的灵活性。只需从 Transform
类继承,并在 __call__
方法中添加您的代码。
一个简单的虚构示例是从数据集中取一列作为用户消息,另一列作为模型响应。实际上,这与 InputOutputToMessages
非常相似。
from torchtune.modules.transforms import Transform
from torchtune.data import Message
from typing import Any, Mapping
from pprint import pprint
class MessageTransform(Transform):
def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
messages = [
Message(
role="user",
content=sample["input"],
masked=True,
eot=True,
),
Message(
role="assistant",
content=sample["output"],
masked=False,
eot=True,
),
]
return {"messages": messages}
input_sample = {"input": "hello world", "output": "bye world"}
transform = MessageTransform()
output_sample = transform(input_sample)
pprint(output_sample)
# {'messages': [Message(role='user', content=['hello world']),
# Message(role='assistant', content=['bye world'])]}
请参阅 创建消息 以了解有关如何操作 Message
对象的更多详细信息。
要将此用于您的数据集,您必须创建一个自定义数据集构建器,该构建器使用底层数据集类 SFTDataset
。
# In data/dataset.py
from torchtune.datasets import SFTDataset
def custom_dataset(tokenizer, **load_dataset_kwargs) -> SFTDataset:
message_transform = MyMessageTransform()
return SFTDataset(
source="json",
data_files="data/my_data.json",
split="train",
message_transform=message_transform,
model_transform=tokenizer,
**load_dataset_kwargs,
)
这可以直接从配置中使用。
dataset:
_component_: data.dataset.custom_dataset