快捷方式

多模态数据集

多模态数据集包含多种数据模式,例如文本 + 图像,可用于训练基于 Transformer 的模型。torchtune 目前仅支持用于视觉语言模型 (VLM) 的多模态文本 + 图像聊天数据集。

在 torchtune 中使用多模态数据集进行微调的主要入口点是 multimodal_chat_dataset() 构建器。这使您可以直接从配置中指定遵循多模态聊天数据格式的本地或 Hugging Face 数据集,并在其上训练您的 VLM。

多模态数据集示例

这是一个用于视觉问答任务的多模态聊天数据集示例。请注意,文本中有一个占位符 "<image>",用于放置图像标记。在下面的示例中,它将被图像特殊标记 <|image|> 替换。

# data/my_data.json
[
    {
        "dialogue": [
            {
                "from": "human",
                "value": "<image>What time is it on the clock?",
            },
            {
                "from": "gpt",
                "value": "It is 10:00 AM.",
            },
        ],
        "image_path": "images/clock.jpg",
    },
    ...,
]
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

transform = Llama3VisionTransform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    prompt_template="torchtune.data.QuestionAnswerTemplate",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=model_transform,
    source="json",
    data_files="data/my_data.json",
    column_map={
        "dialogue": "conversations",
        "image_path": "image",
    },
    image_dir="/home/user/dataset/",  # /home/user/dataset/images/clock.jpg
    image_tag="<image>",
    split="train",
)
tokenized_dict = ds[0]
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nQuestion:<|image|>What time is it on the clock?Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nIt is 10:00AM.<|eot_id|>'
print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])
# In config - model_transforms takes the place of the tokenizer
model_transform:
  _component_: torchtune.models.llama3_2_vision_transform
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
  prompt_template: torchtune.data.QuestionAnswerTemplate
  max_seq_len: 8192

dataset:
  _component_: torchtune.datasets.multimodal.multimodal_chat_dataset
  source: json
  data_files: data/my_data.json
  split: train
  column_map:
    dialogue: conversations
    image_path: image
  image_dir: /home/user/dataset/
  image_tag: "<image>"
  split: train

多模态数据集格式

目前,多模态数据集预计应遵循 "sharegpt" 聊天格式,其中图像路径位于一列,用户-助手对话位于另一列。

|  conversations                     | image        |
|------------------------------------|--------------|
| [{"from": "human", "value": "Q1"}, | images/1.jpg |
|  {"from": "gpt", "value": "A1"}]   |              |

例如,您可以查看 ShareGPT4V 数据集 的模式。

目前,multimodal_chat_dataset() 每个对话样本仅支持一个图像路径。

从 Hugging Face 加载多模态数据集

您只需要将数据集仓库名称传递给 source,然后将其传递给 Hugging Face 的 load_dataset。对于大多数数据集,您还需要通过 split 和/或 name 指定拆分和/或子集。

# In code
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

transform = llama3_2_vision_transform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=model_transform,
    source="Lin-Chen/ShareGPT4V",
    split="train",
    name="ShareGPT4V",
    image_dir="/home/user/dataset/",
    image_tag="<image>",
)
# In config
model_transform:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
  max_seq_len: 8192
  image_size: 560

# Tokenizer is passed into the dataset in the recipe
dataset:
  _component_: torchtune.datasets.multimodal.multimodal_chat_dataset
  source: Lin-Chen/ShareGPT4V
  split: train
  name: ShareGPT4V
  image_dir: /home/user/dataset/
  image_tag: "<image>"

这将使用默认的列名“conversations”和“image”。要更改列名,请使用 column_map 参数(请参阅 重命名列)。

加载本地和远程多模态数据集

要通过 https 加载遵循指令格式的本地或远程数据集,您需要指定 sourcedata_filessplit 参数。有关加载本地或远程文件的更多详细信息,请参阅 Hugging Face 的 load_dataset 文档。请参阅上面的 多模态数据集示例

加载图像

在许多情况下,您的数据集将包含图像路径而不是原始图像本身。multimodal_chat_dataset() 将自动为您处理此问题,但如果您正在为自定义多模态数据集编写自定义消息转换(请参阅 自定义消息转换),您可以直接使用 load_image() 实用程序。

from torchtune.data import load_image
from pathlib import Path

sample = {
    "conversations": [
        {
            "from": "human",
            "value": "What time is it on the clock?",
        },
        {
            "from": "gpt",
            "value": "It is 10:00 AM.",
        },
    ],
    "image": "images/clock.jpg",
}
image_dir = "/home/user/dataset/"
pil_image = load_image(Path(image_dir) / Path(sample["image"]))
print(pil_image)
# <PIL.Image.Image>

然后,您可以将 PIL 图像直接添加到相关消息的内容中。在 Message 中,仅支持 PIL 图像作为图像内容,不支持图像路径或 URL。

from torchtune.data import Message

user_message = None
for msg in sample["conversations"]:
    if msg["from"] == "human":
        user_message = Message(
            role="user",
            content=[
                {"type": "image", "content": pil_image},
                {"type": "text", "content": msg["value"]},
            ]
        )
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>]
print(user_message.text_content)
# What time is it on the clock?

如果数据集中的图像路径是相对路径,则可以使用 multimodal_chat_dataset() 中的 image_dir 参数在本地下载图像的完整路径前缀。

在文本中交错图像

torchtune 支持在文本的任何位置添加多个图像,只要您的模型支持即可。

import PIL
from torchtune.data import Message

image_dog = PIL.Image.new(mode="RGB", size=(4, 4))
image_cat = PIL.Image.new(mode="RGB", size=(4, 4))
image_bird = PIL.Image.new(mode="RGB", size=(4, 4))

user_message = Message(
    role="user",
    content=[
        {"type": "image", "content": image_dog},
        {"type": "text", "content": "This is an image of a dog. "},
        {"type": "image", "content": image_cat},
        {"type": "text", "content": "This is an image of a cat. "},
        {"type": "image", "content": image_bird},
        {"type": "text", "content": "This is a bird, the best pet of the three."},
    ]
)
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>, <PIL.Image.Image>, <PIL.Image.Image>]
print(user_message.text_content)
# This is an image of a dog. This is an image of a cat. This is a bird, the best pet of the three.

您的数据集可能包含图像占位符标记,指示应在文本中的哪个位置引用图像。例如,请参阅 ShareGPT4V <https://hugging-face.cn/datasets/Lin-Chen/ShareGPT4V>,它使用 "<image>"。您可以使用实用程序 format_content_with_images() 轻松创建类似于上述的交错消息内容,该实用程序将图像占位符标记替换为传入的图像。

import PIL
from torchtune.data import Message, format_content_with_images

image_dog = PIL.Image.new(mode="RGB", size=(4, 4))
image_cat = PIL.Image.new(mode="RGB", size=(4, 4))
image_bird = PIL.Image.new(mode="RGB", size=(4, 4))

text = "[img]This is an image of a dog. [img]This is an image of a cat. [img]This is a bird, the best pet of the three."
user_message = Message(
    role="user",
    content=format_content_with_images(
        content=text,
        image_tag="[img]",
        images=[image_dog, image_cat, image_bird],
    ),
)
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>,<PIL.Image.Image>, <PIL.Image.Image>]
print(user_message.text_content)
# This is an image of a dog. This is an image of a cat. This is a bird, the best pet of the three.

当您传入 image_tag 时,multimodal_chat_dataset() 会自动为您处理此问题。

内置多模态数据集

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源