快捷方式

消息

消息是 torchtune 中的核心组件,它控制文本和多模态内容如何进行分词。它作为所有分词器和数据集 API 操作的通用接口。消息包含有关文本内容的信息,哪个角色正在发送文本内容,以及与模型分词器中的特殊标记相关的其他信息。有关消息的各个参数的更多信息,请参阅 Message 的 API 参考。

创建消息

消息可以通过标准类构造函数或直接从字典创建。

from torchtune.data import Message

msg = Message(
    role="user",
    content="Hello world!",
    masked=True,
    eot=True,
    ipython=False,
)
# This is identical
msg = Message.from_dict(
    {
        "role": "user",
        "content": "Hello world!",
        "masked": True,
        "eot": True,
        "ipython": False,
    },
)
print(msg.content)
# [{'type': 'text', 'content': 'Hello world!'}]

内容格式化为字典列表。这是因为消息还可以包含多模态内容,例如图像。

消息中的图像

对于多模态数据集,您需要将图像作为 Image 添加到相应的 Message。要将其添加到消息的开头,只需将其添加到内容列表的前面即可。

import PIL
from torchtune.data import Message

img_msg = Message(
    role="user",
    content=[
        {
            "type": "image",
            # Place your image here
            "content": PIL.Image.new(mode="RGB", size=(4, 4)),
        },
        {"type": "text", "content": "What's in this image?"},
    ],
)

这将指示模型分词器在何处添加图像特殊标记,并将由模型转换适当地处理。

在许多情况下,您将拥有图像路径而不是原始 Image。您可以使用 load_image() 实用程序处理本地路径和远程路径。

import PIL
from torchtune.data import Message, load_image

image_path = "path/to/image.jpg"
img_msg = Message(
    role="user",
    content=[
        {
            "type": "image",
            # Place your image here
            "content": load_image(image_path),
        },
        {"type": "text", "content": "What's in this image?"},
    ],
)

如果您的数据集包含图像标签或占位符文本以指示应在文本中的哪个位置插入图像,则可以使用 format_content_with_images() 将文本拆分为您可以传递到 Message 内容字段的正确内容列表。

import PIL
from torchtune.data import format_content_with_images

content = format_content_with_images(
    "<|image|>hello <|image|>world",
    image_tag="<|image|>",
    images=[PIL.Image.new(mode="RGB", size=(4, 4)), PIL.Image.new(mode="RGB", size=(4, 4))]
)
print(content)
# [
#     {"type": "image", "content": <PIL.Image.Image>},
#     {"type": "text", "content": "hello "},
#     {"type": "image", "content": <PIL.Image.Image>},
#     {"type": "text", "content": "world"}
# ]

消息转换

消息转换是将原始数据格式化为 torchtune Message 对象列表的便捷实用程序。

from torchtune.data import InputOutputToMessages

sample = {
    "input": "What is your name?",
    "output": "I am an AI assistant, I don't have a name."
}
transform = InputOutputToMessages()
output = transform(sample)
for message in output["messages"]:
    print(message.role, message.text_content)
# user What is your name?
# assistant I am an AI assistant, I don't have a name.

有关更多讨论,请参阅 消息转换

使用提示模板格式化消息

提示模板提供了一种将消息格式化为结构化文本模板的方法。您可以简单地对消息列表调用任何从 PromptTemplateInterface 继承的类,它会将相应的文本添加到内容列表中。

from torchtune.models.mistral import MistralChatTemplate
from torchtune.data import Message

msg = Message(
    role="user",
    content="Hello world!",
    masked=True,
    eot=True,
    ipython=False,
)
template = MistralChatTemplate()
templated_msg = template([msg])
print(templated_msg[0].content)
# [{'type': 'text', 'content': '[INST] '},
# {'type': 'text', 'content': 'Hello world!'},
# {'type': 'text', 'content': ' [/INST] '}]

访问消息中的文本内容

from torchtune.models.mistral import MistralChatTemplate
from torchtune.data import Message

msg = Message(
    role="user",
    content="Hello world!",
    masked=True,
    eot=True,
    ipython=False,
)
template = MistralChatTemplate()
templated_msg = template([msg])
print(templated_msg[0].text_content)
# [INST] Hello world! [/INST]

访问消息中的图像

from torchtune.data import Message
import PIL

msg = Message(
    role="user",
    content=[
        {
            "type": "image",
            # Place your image here
            "content": PIL.Image.new(mode="RGB", size=(4, 4)),
        },
        {"type": "text", "content": "What's in this image?"},
    ],
)
if msg.contains_media:
    print(msg.get_media())
# [<PIL.Image.Image image mode=RGB size=4x4 at 0x7F8D27E72740>]

对消息进行分词

所有模型分词器都具有 tokenize_messsages 方法,该方法将 Message 对象列表转换为标记 ID 和损失掩码。

from torchtune.models.mistral import mistral_tokenizer
from torchtune.data import Message

m_tokenizer = mistral_tokenizer(
    path="/tmp/Mistral-7B-v0.1/tokenizer.model",
    prompt_template="torchtune.models.mistral.MistralChatTemplate",
    max_seq_len=8192,
)
msgs = [
    Message(
        role="user",
        content="Hello world!",
        masked=True,
        eot=True,
        ipython=False,
    ),
    Message(
        role="assistant",
        content="Hi, I am an AI assistant.",
        masked=False,
        eot=True,
        ipython=False,
    )
]
tokens, mask = m_tokenizer.tokenize_messages(msgs)
print(tokens)
# [1, 733, 16289, 28793, 22557, 1526, 28808, 28705, 733, 28748, 16289, 28793, 15359, 28725, 315, 837, 396, 16107, 13892, 28723, 2]
print(mask)  # User message is masked from the loss
# [True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False]
print(m_tokenizer.decode(tokens))
# [INST] Hello world!  [/INST] Hi, I am an AI assistant.

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源